Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
  • 10 commits
  • 17 files changed
  • 0 comments
  • 1 contributor

Showing 17 changed files with 235 additions and 364 deletions. Show diff stats Hide diff stats

  1. 126  storm-ml/src/jvm/storm/ml/Main.java
  2. 59  storm-ml/src/jvm/storm/ml/PerceptronTopologyBuilder.java
  3. 42  storm-ml/src/main/java/com/twitter/MainOnlineTopology.java
  4. 23  storm-ml/src/main/java/com/twitter/algorithms/Learner.java
  5. 9  storm-ml/src/main/java/com/twitter/algorithms/LossFunction.java
  6. 2  storm-ml/src/main/java/com/twitter/data/Example.java
  7. 64  storm-ml/src/main/java/com/twitter/storm/example/MainOnlineTopology.java
  8. 6  storm-ml/src/main/java/com/twitter/{algorithms → storm/primitives}/Aggregator.java
  9. 4  storm-ml/src/main/java/com/twitter/storm/primitives/BaseTrainingSpout.java
  10. 11  storm-ml/src/main/java/com/twitter/storm/primitives/EvaluationBolt.java
  11. 33  storm-ml/src/main/java/com/twitter/storm/primitives/ExampleTrainingSpout.java
  12. 128  storm-ml/src/main/java/com/twitter/storm/primitives/MLTopologyBuilder.java
  13. 36  storm-ml/src/main/java/com/twitter/storm/primitives/TrainingSpout.java
  14. 27  storm-ml/src/main/java/com/twitter/storm/primitives/example/ExampleTrainingSpout.java
  15. 7  storm-ml/src/main/java/com/twitter/storm/primitives/{ → example}/LocalLearner.java
  16. 9  storm-ml/src/main/java/com/twitter/util/Datautil.java
  17. 13  storm-ml/src/main/java/com/twitter/util/datautil.clj
126  storm-ml/src/jvm/storm/ml/Main.java
... ...
@@ -1,126 +0,0 @@
1  
-package storm.ml;
2  
-
3  
-import java.lang.Boolean;
4  
-import java.math.BigDecimal;
5  
-import java.util.List;
6  
-import java.util.ArrayList;
7  
-
8  
-import backtype.storm.Config;
9  
-import backtype.storm.LocalCluster;
10  
-import backtype.storm.StormSubmitter;
11  
-import backtype.storm.task.OutputCollector;
12  
-import backtype.storm.task.TopologyContext;
13  
-import backtype.storm.testing.TestWordSpout;
14  
-import backtype.storm.topology.OutputFieldsDeclarer;
15  
-import backtype.storm.topology.TopologyBuilder;
16  
-import backtype.storm.topology.base.BaseRichBolt;
17  
-import backtype.storm.topology.base.BaseRichSpout;
18  
-import backtype.storm.tuple.Fields;
19  
-import backtype.storm.tuple.Tuple;
20  
-import backtype.storm.tuple.Values;
21  
-import backtype.storm.utils.Utils;
22  
-import java.util.Map;
23  
-
24  
-import org.javatuples.Pair;
25  
-
26  
-import storm.ml.PerceptronTopologyBuilder;
27  
-
28  
-public class Main {
29  
-    
30  
-    public static class TrainingSpout extends BaseRichSpout {
31  
-        OutputCollector _collector;
32  
-
33  
-        @Override
34  
-        public void prepare(Map conf, TopologyContext context, OutputCollector collector) {
35  
-            _collector = collector;
36  
-        }
37  
-
38  
-        @Override
39  
-        public void execute(Tuple tuple) {
40  
-            _t
41  
-            _collector.emit(tuple, new Values(tuple.getString(0) + "!!!"));
42  
-            _collector.ack(tuple);
43  
-        }
44  
-
45  
-        @Override
46  
-        public void declareOutputFields(OutputFieldsDeclarer declarer) {
47  
-            declarer.declare(new Fields("word"));
48  
-        }
49  
-
50  
-    }
51  
-
52  
-    public static class TrainingBolt extends BaseRichBolt {
53  
-        OutputCollector _collector;
54  
-
55  
-        @Override
56  
-        public void prepare(Map conf, TopologyContext context, OutputCollector collector) {
57  
-            _collector = collector;
58  
-        }
59  
-
60  
-        @Override
61  
-        public void execute(Tuple tuple) {
62  
-            _t
63  
-            _collector.emit(tuple, new Values(tuple.getString(0) + "!!!"));
64  
-            _collector.ack(tuple);
65  
-        }
66  
-
67  
-        @Override
68  
-        public void declareOutputFields(OutputFieldsDeclarer declarer) {
69  
-            declarer.declare(new Fields("word"));
70  
-        }
71  
-
72  
-    }
73  
-
74  
-    public perceptronTopology() {
75  
-        TopologyBuilder builder = new TopologyBuilder();
76  
-
77  
-        builder.setSpout("word", new TestWordSpout(), 10);
78  
-        builder.setBolt("exclaim1", new ExclamationBolt(), 3).shuffleGrouping("word");
79  
-        builder.setBolt("exclaim2", new ExclamationBolt(), 2).shuffleGrouping("exclaim1");
80  
-
81  
-        Config conf = new Config();
82  
-        conf.setDebug(true);
83  
-
84  
-        if (args != null && args.length > 0) {
85  
-            conf.setNumWorkers(3);
86  
-
87  
-            StormSubmitter.submitTopology(args[0], conf, builder.createTopology());
88  
-        } else {
89  
-
90  
-            LocalCluster cluster = new LocalCluster();
91  
-            cluster.submitTopology("test", conf, builder.createTopology());
92  
-            Utils.sleep(10000);
93  
-            cluster.killTopology("test");
94  
-            cluster.shutdown();
95  
-        }
96  
-    }
97  
-
98  
-    public static void main(String[] keywords) {
99  
-        perceptronTopology();
100  
-//        List<Pair<List<BigDecimal>, Boolean>> training_set = new ArrayList<Pair<List<BigDecimal>, Boolean>>(4);
101  
-//        List<BigDecimal> input_vector = new ArrayList<BigDecimal>(3);
102  
-//
103  
-//        input_vector.add(new BigDecimal(1));
104  
-//        input_vector.add(new BigDecimal(0));
105  
-//        input_vector.add(new BigDecimal(0));
106  
-//        training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, true));
107  
-//
108  
-//        input_vector.add(new BigDecimal(1));
109  
-//        input_vector.add(new BigDecimal(0));
110  
-//        input_vector.add(new BigDecimal(1));
111  
-//        training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, true));
112  
-//
113  
-//        input_vector.add(new BigDecimal(1));
114  
-//        input_vector.add(new BigDecimal(1));
115  
-//        input_vector.add(new BigDecimal(0));
116  
-//        training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, true));
117  
-//
118  
-//        input_vector.add(new BigDecimal(1));
119  
-//        input_vector.add(new BigDecimal(1));
120  
-//        input_vector.add(new BigDecimal(1));
121  
-//        training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, false));
122  
-//
123  
-//        PerceptronTopologyBuilder ptb = new PerceptronTopologyBuilder(3, 0.5, 0.1);
124  
-//        ptb.train(training_set);
125  
-    }
126  
-}
59  storm-ml/src/jvm/storm/ml/PerceptronTopologyBuilder.java
... ...
@@ -1,59 +0,0 @@
1  
-package storm.ml;
2  
-
3  
-import java.lang.Integer;
4  
-import java.lang.Double;
5  
-import java.lang.Boolean;
6  
-import java.math.BigDecimal;
7  
-import java.util.List;
8  
-import java.util.ArrayList;
9  
-
10  
-import org.javatuples.Pair;
11  
-
12  
-public class PerceptronTopologyBuilder {
13  
-    public final Integer size;
14  
-    public final Double threshold;
15  
-    public final Double learning_rate;
16  
-
17  
-    privatez List<BigDecimal> weights;
18  
-
19  
-    public PerceptronTopologyBuilder(Integer size, Double threshold, Double learning_rate) {
20  
-        this.size = size;                   // size of the weight array and input
21  
-        this.threshold = threshold;         // margin to determine positive results
22  
-        this.learning_rate = learning_rate; // adaptation factor for the weights
23  
-
24  
-        this.weights = new ArrayList<BigDecimal>(size);
25  
-        int i; for (i=0; i<size; i++)
26  
-            this.weights.add(new BigDecimal(0));
27  
-    }
28  
-
29  
-    private BigDecimal dot_product(List<BigDecimal> vector1, List<BigDecimal> vector2) {
30  
-        BigDecimal result = new BigDecimal(0);
31  
-        int i; for (i=0; i<this.size; i++)
32  
-            result.add(vector1.get(i).multiply(vector2.get(i)));
33  
-
34  
-        return result;
35  
-    }
36  
-
37  
-    public void train(List<Pair<List<BigDecimal>, Boolean>> training_set) {
38  
-        while (true) {
39  
-            int error_count = 0;
40  
-            for (Pair<List<BigDecimal>, Boolean> training_pair : training_set) {
41  
-                List<BigDecimal> input_vector = training_pair.getValue0();
42  
-                Integer desired_output        = training_pair.getValue1() ? 1 : 0;
43  
-                System.out.println(String.format("%s", this.weights));
44  
-
45  
-                int result = dot_product(input_vector, this.weights).compareTo(new BigDecimal(threshold)) > 0 ? 1 : 0;
46  
-
47  
-                int error = desired_output - result;
48  
-                if (error != 0) {
49  
-                    error_count += 1;
50  
-                    int i; for (i=0; i<this.size; i++)
51  
-                        this.weights.set(i, this.weights.get(i).add(input_vector.get(i).multiply(new BigDecimal(this.learning_rate * error))));
52  
-                }
53  
-            }
54  
-            if (error_count == 0)
55  
-                break;
56  
-            System.out.println();
57  
-        }
58  
-    }
59  
-}
42  storm-ml/src/main/java/com/twitter/MainOnlineTopology.java
... ...
@@ -1,42 +0,0 @@
1  
-package com.twitter;
2  
-
3  
-import backtype.storm.Config;
4  
-import backtype.storm.LocalCluster;
5  
-import backtype.storm.LocalDRPC;
6  
-import backtype.storm.drpc.LinearDRPCTopologyBuilder;
7  
-import backtype.storm.topology.TopologyBuilder;
8  
-import backtype.storm.utils.Utils;
9  
-
10  
-import com.twitter.algorithms.Aggregator;
11  
-import com.twitter.storm.primitives.EvaluationBolt;
12  
-import com.twitter.storm.primitives.LocalLearner;
13  
-import com.twitter.storm.primitives.TrainingSpout;
14  
-
15  
-public class MainOnlineTopology {
16  
-    public static final String MEMCACHED_SERVERS = "127.0.0.1:11211";
17  
-    static Double threshold = 0.5;
18  
-    static Double bias = 1.0;
19  
-
20  
-    public static void main(String[] args) throws Exception {
21  
-        TopologyBuilder builder = new TopologyBuilder();
22  
-        LocalDRPC drpc = new LocalDRPC();
23  
-
24  
-        builder.setSpout("example_spitter", new TrainingSpout());
25  
-        builder.setBolt("local_learner", new LocalLearner(2, MEMCACHED_SERVERS), 1).shuffleGrouping("example_spitter");
26  
-        builder.setBolt("aggregator", new Aggregator(MEMCACHED_SERVERS)).globalGrouping("local_learner");
27  
-
28  
-        LinearDRPCTopologyBuilder drpc_builder = new LinearDRPCTopologyBuilder("evaluate");
29  
-        drpc_builder.addBolt(new EvaluationBolt(bias, threshold, MEMCACHED_SERVERS), 3);
30  
-
31  
-        Config conf = new Config();
32  
-        conf.setDebug(true);
33  
-        LocalCluster cluster = new LocalCluster();
34  
-        cluster.submitTopology("learning", conf, builder.createTopology());
35  
-        // cluster.submitTopology("evaluation", conf, drpc_builder.createLocalTopology(drpc));
36  
-
37  
-        Utils.sleep(10000);
38  
-        cluster.killTopology("learning");
39  
-        cluster.shutdown();
40  
-
41  
-    }
42  
-}
23  storm-ml/src/main/java/com/twitter/algorithms/Learner.java
@@ -7,16 +7,17 @@
7 7
 
8 8
 import net.spy.memcached.MemcachedClient;
9 9
 
  10
+import org.apache.commons.lang.ArrayUtils;
10 11
 import org.apache.log4j.Logger;
11 12
 
12 13
 import com.twitter.data.Example;
13  
-import com.twitter.storm.primitives.LocalLearner;
14 14
 import com.twitter.util.Datautil;
15 15
 import com.twitter.util.MathUtil;
16 16
 
17 17
 public class Learner implements Serializable {
18  
-    public static Logger LOG = Logger.getLogger(LocalLearner.class);
  18
+    public static Logger LOG = Logger.getLogger(Learner.class);
19 19
 
  20
+    private double threshold = 0.0;
20 21
     protected double[] weights;
21 22
     protected LossFunction lossFunction;
22 23
     int numExamples = 0;
@@ -30,19 +31,27 @@ public Learner(int dimension) {
30 31
         lossFunction = new LossFunction(2);
31 32
     }
32 33
 
  34
+    public void setLocalWeights(List<Double> localWeights) {
  35
+        Double[] weights_double = localWeights.toArray(new Double[localWeights.size()]);
  36
+        this.setWeights(ArrayUtils.toPrimitive(weights_double));
  37
+
  38
+    }
  39
+
33 40
     public void update(Example example, int epoch, MemcachedClient memcache) {
34  
-        String cas_weights = (String) memcache.get("weights");
  41
+        String cas_weights = (String) memcache.get("model");
35 42
         List<Double> weights = Datautil.parse_str_vector(cas_weights);
36  
-        LOG.error("double weights" + weights);
  43
+        setLocalWeights(weights);
37 44
         int predicted = predict(example);
38 45
         updateStats(example, predicted);
39  
-        LOG.debug("EXAMPLE " + example.label + " PREDICTED: " + predicted);
  46
+        LOG.debug("EXAMPLE " + example.x[0] + "," + example.x[1] + " LABEL" + example.label + " PREDICTED: "
  47
+                + predicted);
40 48
         if (example.isLabeled) {
41 49
             if ((double) predicted != example.label) {
42 50
                 List<Double> gradient = lossFunction.gradient(example, predicted);
43 51
                 gradientSum += MathUtil.l2norm(gradient);
44 52
                 double eta = getLearningRate(example, epoch);
45  
-                MathUtil.plus(weights, MathUtil.times(gradient, -1.0 * eta));
  53
+                LOG.debug("NEW WEIGHTS" + MathUtil.plus(weights, MathUtil.times(gradient, -1.0 * eta)));
  54
+                setLocalWeights(MathUtil.plus(weights, MathUtil.times(gradient, -1.0 * eta)));
46 55
             }
47 56
         }
48 57
         displayStats();
@@ -74,7 +83,7 @@ public void initWeights(double[] newWeights) {
74 83
 
75 84
     public int predict(Example example) {
76 85
         double dot = MathUtil.dot(weights, example.x);
77  
-        return (dot >= 0.0) ? 1 : -1;
  86
+        return (dot >= this.threshold) ? 1 : -1;
78 87
     }
79 88
 
80 89
     protected void updateStats(Example example, int prediction) {
9  storm-ml/src/main/java/com/twitter/algorithms/LossFunction.java
@@ -4,13 +4,15 @@
4 4
 import java.util.ArrayList;
5 5
 import java.util.List;
6 6
 
  7
+import org.apache.log4j.Logger;
  8
+
7 9
 import com.twitter.data.Example;
  10
+import com.twitter.storm.primitives.example.LocalLearner;
8 11
 
9 12
 public class LossFunction implements Serializable {
10  
-    private List<Double> grad; // gradient
  13
+    public static Logger LOG = Logger.getLogger(LocalLearner.class);
11 14
 
12 15
     public LossFunction(int dimension) {
13  
-        grad = new ArrayList<Double>();
14 16
     }
15 17
 
16 18
     public double get(Example e, int prediction) {
@@ -18,9 +20,10 @@ public double get(Example e, int prediction) {
18 20
     }
19 21
 
20 22
     public List<Double> gradient(Example e, int prediction) {
  23
+        List<Double> grad = new ArrayList<Double>();
21 24
         double f = -1.0 * (e.label - prediction);
22 25
         for (int i = 0; i < e.x.length; i++) {
23  
-            grad.set(i, f * e.x[i]);
  26
+            grad.add(f * e.x[0]);
24 27
         }
25 28
         return grad;
26 29
     }
2  storm-ml/src/main/java/com/twitter/data/Example.java
@@ -14,7 +14,7 @@
14 14
 
15 15
     /**
      * Creates an example with a feature array of the requested size and
      * marks it as labeled.
      *
      * @param dimension number of features to allocate
      */
     public Example(int dimension) {
         this.isLabeled = true;
         this.x = new double[dimension];
     }
19 19
 
20 20
     public String toString() {
64  storm-ml/src/main/java/com/twitter/storm/example/MainOnlineTopology.java
... ...
@@ -0,0 +1,64 @@
  1
+package com.twitter.storm.example;
  2
+
  3
+import java.util.ArrayList;
  4
+import java.util.List;
  5
+
  6
+import net.spy.memcached.AddrUtil;
  7
+import net.spy.memcached.MemcachedClient;
  8
+import net.spy.memcached.internal.OperationFuture;
  9
+
  10
+import org.apache.log4j.Logger;
  11
+
  12
+import backtype.storm.Config;
  13
+import backtype.storm.LocalCluster;
  14
+import backtype.storm.LocalDRPC;
  15
+import backtype.storm.utils.Utils;
  16
+
  17
+import com.twitter.storm.primitives.EvaluationBolt;
  18
+import com.twitter.storm.primitives.MLTopologyBuilder;
  19
+import com.twitter.storm.primitives.example.ExampleTrainingSpout;
  20
+import com.twitter.storm.primitives.example.LocalLearner;
  21
+
  22
+public class MainOnlineTopology {
  23
+    public static final String MEMCACHED_SERVERS = "127.0.0.1:11211";
  24
+    public static Logger LOG = Logger.getLogger(MainOnlineTopology.class);
  25
+    static Double threshold = 0.5;
  26
+    static Double bias = 1.0;
  27
+
  28
+    public static void main(String[] args) throws Exception {
  29
+        MemcachedClient memcache = new MemcachedClient(AddrUtil.getAddresses(MEMCACHED_SERVERS));
  30
+        OperationFuture promise = memcache.set("model", 0, "[0.1, -0.1]");
  31
+        promise.get();
  32
+
  33
+        Config topology_conf = new Config();
  34
+        String topology_name;
  35
+        if (args == null || args.length == 0)
  36
+            topology_name = "perceptron";
  37
+        else
  38
+            topology_name = args[0];
  39
+
  40
+        MLTopologyBuilder ml_topology_builder = new MLTopologyBuilder(topology_name, MEMCACHED_SERVERS);
  41
+
  42
+        ml_topology_builder.setTrainingSpout(new ExampleTrainingSpout());
  43
+        ml_topology_builder.setTrainingBolt(new LocalLearner(2, MEMCACHED_SERVERS));
  44
+        ml_topology_builder.setEvaluationBolt(new EvaluationBolt(1.0, 0.0, MEMCACHED_SERVERS));
  45
+
  46
+        if (args == null || args.length == 0) {
  47
+            LocalDRPC drpc = new LocalDRPC();
  48
+            LocalCluster cluster = new LocalCluster();
  49
+
  50
+            cluster.submitTopology(topology_name, topology_conf,
  51
+                    ml_topology_builder.createLocalTopology("evaluate", drpc));
  52
+
  53
+            List<Double> testVector = new ArrayList<Double>();
  54
+            testVector.add(3.0);
  55
+            testVector.add(1.0);
  56
+            String result = drpc.execute("evaluate", testVector.toString());
  57
+            LOG.error("RESULT: " + result);
  58
+
  59
+            Utils.sleep(10000);
  60
+            cluster.killTopology("perceptron");
  61
+            cluster.shutdown();
  62
+        }
  63
+    }
  64
+}
6  .../main/java/com/twitter/algorithms/Aggregator.java → ...java/com/twitter/storm/primitives/Aggregator.java
... ...
@@ -1,4 +1,4 @@
1  
-package com.twitter.algorithms;
  1
+package com.twitter.storm.primitives;
2 2
 
3 3
 import java.io.IOException;
4 4
 import java.util.List;
@@ -15,6 +15,7 @@
15 15
 import backtype.storm.topology.base.BaseRichBolt;
16 16
 import backtype.storm.tuple.Tuple;
17 17
 
  18
+import com.twitter.util.Datautil;
18 19
 import com.twitter.util.MathUtil;
19 20
 
20 21
 public class Aggregator extends BaseRichBolt {
@@ -43,10 +44,9 @@ public void execute(Tuple tuple) {
43 44
             MathUtil.plus(aggregateWeights, weight);
44 45
         }
45 46
         totalUpdateWeight += parallelUpdateWeight;
46  
-        LOG.info("aggregate weights" + aggregateWeights);
47 47
         MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight);
48 48
         if (aggregateWeights != null) {
49  
-            memcache.set("model", 3600 * 24, aggregateWeights);
  49
+            memcache.set("model", 3600 * 24, Datautil.toStrVector(aggregateWeights));
50 50
         }
51 51
 
52 52
     }
4  storm-ml/src/main/java/com/twitter/storm/primitives/BaseTrainingSpout.java
@@ -9,13 +9,13 @@
9 9
 import backtype.storm.tuple.Fields;
10 10
 
11 11
 public abstract class BaseTrainingSpout extends BaseRichSpout {
12  
-    SpoutOutputCollector _collector;
  12
+    protected SpoutOutputCollector _collector;
13 13
 
14 14
     public void open(Map conf, TopologyContext context, SpoutOutputCollector collector) {
15 15
         this._collector = collector;
16 16
     }
17 17
 
18 18
     public void declareOutputFields(OutputFieldsDeclarer declarer) {
19  
-        declarer.declare(new Fields("example", "label"));
  19
+        declarer.declare(new Fields("example-x", "example-y", "label"));
20 20
     }
21 21
 }
11  storm-ml/src/main/java/com/twitter/storm/primitives/EvaluationBolt.java
@@ -38,19 +38,20 @@ public void prepare(Map stormConf, TopologyContext context) {
38 38
     }
39 39
 
40 40
     List<Double> get_latest_weights() {
41  
-        String weights = (String) this.memcache.get("weights");
  41
+        String weights = (String) this.memcache.get("model");
42 42
         return Datautil.parse_str_vector(weights);
43 43
     }
44 44
 
45 45
     public void execute(Tuple tuple, BasicOutputCollector collector) {
46  
-        List<Double> weights = get_latest_weights();
47  
-
48 46
         String input_str = tuple.getString(1);
  47
+
  48
+        List<Double> weights = get_latest_weights();
49 49
         List<Double> input = Datautil.parse_str_vector(input_str);
50 50
 
51  
-        Double result = Datautil.dot_product(input, weights) + bias;
  51
+        Double evaluation = Datautil.dot_product(input, weights) + this.bias;
  52
+        String result = evaluation > this.threshold ? "1" : "-1";
52 53
 
53  
-        collector.emit(new Values(tuple.getValue(0), result > this.threshold ? 1 : 0));
  54
+        collector.emit(new Values(tuple.getString(0), result));
54 55
     }
55 56
 
56 57
     public void declareOutputFields(OutputFieldsDeclarer declarer) {
33  storm-ml/src/main/java/com/twitter/storm/primitives/ExampleTrainingSpout.java
... ...
@@ -1,33 +0,0 @@
1  
-package com.twitter.storm.primitives;
2  
-
3  
-import java.util.ArrayList;
4  
-import java.util.List;
5  
-
6  
-import backtype.storm.tuple.Values;
7  
-
8  
-public class ExampleTrainingSpout extends BaseTrainingSpout {
9  
-    int samples_count = 0;
10  
-    int max_samples = 100;
11  
-
12  
-    public static int get_label(Double x, Double y) {
13  
-        // arbitrary expected output (for testing purposes)
14  
-        return (2 * x + 1 < y) ? 1 : 0;
15  
-    }
16  
-
17  
-    public void nextTuple() {
18  
-        if (this.samples_count < this.max_samples) {
19  
-            Double x = 100 * Math.random();
20  
-            Double y = 100 * Math.random();
21  
-
22  
-            List<Double> example = new ArrayList<Double>();
23  
-            example.add(x);
24  
-            example.add(y);
25  
-
26  
-            int label = ExampleTrainingSpout.get_label(x, y);
27  
-
28  
-            _collector.emit(new Values(example.toString(), label));
29  
-
30  
-            this.samples_count++;
31  
-        }
32  
-    }
33  
-}
128  storm-ml/src/main/java/com/twitter/storm/primitives/MLTopologyBuilder.java
@@ -4,64 +4,132 @@
4 4
 import backtype.storm.drpc.DRPCSpout;
5 5
 import backtype.storm.drpc.ReturnResults;
6 6
 import backtype.storm.generated.StormTopology;
  7
+import backtype.storm.topology.IBasicBolt;
  8
+import backtype.storm.topology.IRichBolt;
7 9
 import backtype.storm.topology.TopologyBuilder;
8  
-import backtype.storm.topology.base.BaseRichBolt;
9  
-import backtype.storm.topology.base.BaseRichSpout;
10 10
 
11 11
 public class MLTopologyBuilder {
12 12
 
13  
-    public static final String MEMCACHED_SERVERS = "127.0.0.1:11211";
14  
-    private BaseRichBolt trainingBolt;
15  
-    private BaseRichSpout trainingSpout;
  13
+    String topology_prefix;
  14
+    String memcached_servers;
16 15
 
17  
-    public TopologyBuilder prepareTopology(ILocalDRPC drpc) {
18  
-        return prepareTopology(drpc, 3.0, 0.0, 3.0, MEMCACHED_SERVERS);
  16
+    BaseTrainingSpout training_spout;
  17
+    Number training_spout_parallelism;
  18
+
  19
+    IBasicBolt basic_training_bolt;
  20
+    IRichBolt rich_training_bolt;
  21
+    Number training_bolt_parallelism;
  22
+
  23
+    IBasicBolt basic_evaluation_bolt;
  24
+    IRichBolt rich_evaluation_bolt;
  25
+    Number evaluation_bolt_parallelism;
  26
+
  27
+    public MLTopologyBuilder(String topologyPrefix, String memcached_servers) {
  28
+        this.memcached_servers = memcached_servers;
  29
+        this.topology_prefix = topologyPrefix;
  30
+    }
  31
+
  32
+    public TopologyBuilder prepareTopology(String drpcFunctionName, ILocalDRPC drpc) {
  33
+        return prepareTopology(drpcFunctionName, drpc, 1.0, 0.0, 0.5);
  34
+    }
  35
+
  36
+    public void setTrainingSpout(BaseTrainingSpout exampleTrainingSpout, Number parallelism) {
  37
+        this.training_spout = exampleTrainingSpout;
  38
+        this.training_spout_parallelism = parallelism;
  39
+    }
  40
+
  41
+    public void setTrainingSpout(BaseTrainingSpout exampleTrainingSpout) {
  42
+        setTrainingSpout(exampleTrainingSpout, 1);
  43
+    }
  44
+
  45
+    public void setTrainingBolt(IBasicBolt training_bolt, Number parallelism) {
  46
+        this.basic_training_bolt = training_bolt;
  47
+        this.rich_training_bolt = null;
  48
+        this.training_bolt_parallelism = parallelism;
  49
+    }
  50
+
  51
+    public void setTrainingBolt(IBasicBolt training_bolt) {
  52
+        setTrainingBolt(training_bolt, 1);
  53
+    }
  54
+
  55
+    public void setTrainingBolt(IRichBolt training_bolt, Number parallelism) {
  56
+        this.rich_training_bolt = training_bolt;
  57
+        this.basic_training_bolt = null;
  58
+        this.training_bolt_parallelism = parallelism;
  59
+    }
  60
+
  61
+    public void setTrainingBolt(IRichBolt training_bolt) {
  62
+        setTrainingBolt(training_bolt, 1);
  63
+    }
  64
+
  65
+    public void setEvaluationBolt(IBasicBolt evaluation_bolt, Number parallelism) {
  66
+        this.basic_evaluation_bolt = evaluation_bolt;
  67
+        this.rich_evaluation_bolt = null;
  68
+        this.evaluation_bolt_parallelism = parallelism;
  69
+    }
  70
+
  71
+    public void setEvaluationBolt(IBasicBolt evaluation_bolt) {
  72
+        setEvaluationBolt(evaluation_bolt, 1);
19 73
     }
20 74
 
21  
-    public void setTrainingBolt(BaseRichBolt trainingBolt) {
22  
-        this.trainingBolt = trainingBolt;
  75
+    public void setEvaluationBolt(IRichBolt evaluation_bolt, Number parallelism) {
  76
+        this.rich_evaluation_bolt = evaluation_bolt;
  77
+        this.basic_evaluation_bolt = null;
  78
+        this.evaluation_bolt_parallelism = parallelism;
23 79
     }
24 80
 
25  
-    public void setTrainingSpout(BaseRichSpout trainingSpout) {
26  
-        this.trainingSpout = trainingSpout;
  81
+    public void setEvaluationBolt(IRichBolt evaluation_bolt) {
  82
+        setEvaluationBolt(evaluation_bolt, 1);
27 83
     }
28 84
 
29  
-    public TopologyBuilder prepareTopology(ILocalDRPC drpc, double bias, double threshold, double learning_rate,
30  
-            String memcached_servers) {
  85
+    public TopologyBuilder prepareTopology(String drpcFunctionName, ILocalDRPC drpc, double bias, double threshold,
  86
+            double learning_rate) {
31 87
         TopologyBuilder topology_builder = new TopologyBuilder();
32 88
 
33 89
         // training
34  
-        topology_builder.setSpout("training-spout", new ExampleTrainingSpout());
35  
-
36  
-        topology_builder.setBolt("training-bolt", new LocalLearner(bias, threshold, learning_rate, MEMCACHED_SERVERS))
37  
-                .shuffleGrouping("training-spout");
  90
+        topology_builder.setSpout(this.topology_prefix + "-training-spout", this.training_spout,
  91
+                this.training_spout_parallelism);
  92
+
  93
+        if (this.rich_training_bolt == null) {
  94
+            topology_builder.setBolt(this.topology_prefix + "-training-bolt", this.basic_training_bolt,
  95
+                    this.training_bolt_parallelism).shuffleGrouping(this.topology_prefix + "-training-spout");
  96
+        } else {
  97
+            topology_builder.setBolt(this.topology_prefix + "-training-bolt", this.rich_training_bolt,
  98
+                    this.training_bolt_parallelism).shuffleGrouping(this.topology_prefix + "-training-spout");
  99
+        }
  100
+        topology_builder.setBolt("aggregator", new Aggregator(this.memcached_servers)).globalGrouping(
  101
+                this.topology_prefix + "-training-bolt");
38 102
 
39 103
         // evaluation
40 104
         DRPCSpout drpc_spout;
  105
+
41 106
         if (drpc != null)
42  
-            drpc_spout = new DRPCSpout("evaluate", drpc);
  107
+            drpc_spout = new DRPCSpout(drpcFunctionName, drpc);
43 108
         else
44  
-            drpc_spout = new DRPCSpout("evaluate");
45  
-
46  
-        topology_builder.setSpout("drpc-spout", drpc_spout);
  109
+            drpc_spout = new DRPCSpout(drpcFunctionName);
47 110
 
48  
-        topology_builder.setBolt(
49  
-                "drpc-evaluation",
50  
-                new EvaluationBolt(PerceptronDRPCTopology.bias, PerceptronDRPCTopology.threshold,
51  
-                        PerceptronDRPCTopology.MEMCACHED_SERVERS)).shuffleGrouping("drpc-spout");
  111
+        topology_builder.setSpout(this.topology_prefix + "-drpc-spout", drpc_spout);
52 112
 
53  
-        topology_builder.setBolt("drpc-return", new ReturnResults()).shuffleGrouping("drpc-evaluation");
  113
+        if (this.rich_evaluation_bolt == null) {
  114
+            topology_builder.setBolt(this.topology_prefix + "-drpc-evaluation", this.basic_evaluation_bolt,
  115
+                    this.evaluation_bolt_parallelism).shuffleGrouping(this.topology_prefix + "-drpc-spout");
  116
+        } else {
  117
+            topology_builder.setBolt(this.topology_prefix + "-drpc-evaluation", this.rich_evaluation_bolt,
  118
+                    this.evaluation_bolt_parallelism).shuffleGrouping(this.topology_prefix + "-drpc-spout");
  119
+        }
54 120
 
  121
+        topology_builder.setBolt(this.topology_prefix + "-drpc-return", new ReturnResults()).shuffleGrouping(
  122
+                this.topology_prefix + "-drpc-evaluation");
55 123
         // return
56 124
         return topology_builder;
57 125
 
58 126
     }
59 127
 
60  
-    public StormTopology createLocalTopology(ILocalDRPC drpc) {
61  
-        return prepareTopology(drpc).createTopology();
  128
+    public StormTopology createLocalTopology(String drpcFunctionName, ILocalDRPC drpc) {
  129
+        return prepareTopology(drpcFunctionName, drpc).createTopology();
62 130
     }
63 131
 
64  
-    public StormTopology createRemoteTopology() {
65  
-        return prepareTopology(null).createTopology();
  132
+    public StormTopology createRemoteTopology(String drpcFunctionName) {
  133
+        return prepareTopology(drpcFunctionName, null).createTopology();
66 134
     }
67 135
 }
36  storm-ml/src/main/java/com/twitter/storm/primitives/TrainingSpout.java
... ...
@@ -1,36 +0,0 @@
1  
-package com.twitter.storm.primitives;
2  
-
3  
-import java.util.List;
4  
-import java.util.Map;
5  
-
6  
-import backtype.storm.spout.SpoutOutputCollector;
7  
-import backtype.storm.task.TopologyContext;
8  
-import backtype.storm.topology.OutputFieldsDeclarer;
9  
-import backtype.storm.topology.base.BaseRichSpout;
10  
-import backtype.storm.tuple.Fields;
11  
-import backtype.storm.tuple.Values;
12  
-import backtype.storm.utils.Utils;
13  
-
14  
-import com.twitter.util.Datautil;
15  
-
16  
-public class TrainingSpout extends BaseRichSpout {
17  
-    SpoutOutputCollector _collector;
18  
-
19  
-    public void open(Map conf, TopologyContext context, SpoutOutputCollector collector) {
20  
-        _collector = collector;
21  
-    }
22  
-
23  
-    public void nextTuple() {
24  
-        Utils.sleep(100);
25  
-        List<Double[]> dataSet = new Datautil().readTrainingFile();
26  
-        for (Double[] trainingItem : dataSet) {
27  
-            _collector.emit(new Values(trainingItem));
28  
-        }
29  
-
30  
-    }
31  
-
32  
-    public void declareOutputFields(OutputFieldsDeclarer declarer) {
33  
-        declarer.declare(new Fields("trainingItem1", "t2", "t3"));
34  
-    }
35  
-
36  
-}
27  storm-ml/src/main/java/com/twitter/storm/primitives/example/ExampleTrainingSpout.java
... ...
@@ -0,0 +1,27 @@
  1
+package com.twitter.storm.primitives.example;
  2
+
  3
+import backtype.storm.tuple.Values;
  4
+
  5
+import com.twitter.storm.primitives.BaseTrainingSpout;
  6
+
  7
+public class ExampleTrainingSpout extends BaseTrainingSpout {
  8
+    int samples_count = 0;
  9
+    int max_samples = 100;
  10
+
  11
+    public static double get_label(Double x, Double y) {
  12
+        // arbitrary expected output (for testing purposes)
  13
+        return (2 * x + -1 > y) ? 1.0 : -1.0;
  14
+    }
  15
+
  16
+    public void nextTuple() {
  17
+        if (this.samples_count < this.max_samples) {
  18
+            Double x = 10 * Math.random();
  19
+            Double y = 5.0;
  20
+            double label = ExampleTrainingSpout.get_label(x, y);
  21
+
  22
+            _collector.emit(new Values(x, y, label));
  23
+
  24
+            this.samples_count++;
  25
+        }
  26
+    }
  27
+}
7  ...va/com/twitter/storm/primitives/LocalLearner.java → ...witter/storm/primitives/example/LocalLearner.java
... ...
@@ -1,4 +1,4 @@
1  
-package com.twitter.storm.primitives;
  1
+package com.twitter.storm.primitives.example;
2 2
 
3 3
 import java.io.IOException;
4 4
 import java.util.ArrayList;
@@ -48,8 +48,6 @@ public LocalLearner(int dimension, Learner onlinePerceptron, String memcached_se
48 48
             // this.hashFunction = hashAll;
49 49
             this.memcached_servers = memcached_servers;
50 50
             weightVector = new double[dimension];
51  
-            weightVector[0] = -6.8;
52  
-            weightVector[1] = -0.8;
53 51
             learner.setWeights(weightVector);
54 52
         } catch (Exception e) {
55 53
 
@@ -62,7 +60,8 @@ public void execute(Tuple tuple) {
62 60
         example.x[1] = (Double) tuple.getValue(1);
63 61
         example.label = (Double) tuple.getValue(2);
64 62
         learner.update(example, 1, memcache);
65  
-        LOG.debug("getwe" + learner.getWeights());
  63
+        LOG.debug("local weights" + learner.getWeightsArray() + " parallel weights "
  64
+                + learner.getParallelUpdateWeight());
66 65
         _collector.emit(new Values(learner.getWeightsArray(), learner.getParallelUpdateWeight()));
67 66
         _collector.ack(tuple);
68 67
     }
9  storm-ml/src/main/java/com/twitter/util/Datautil.java
@@ -54,4 +54,13 @@ public static Double dot_product(List<Double> vector_a, List<Double> vector_b) {
54 54
         }
55 55
         return lines;
56 56
     }
  57
+
  58
+    public static String toStrVector(List<Double> aggregateWeights) {
  59
+        String acc = "[";
  60
+        for (Double weight : aggregateWeights) {
  61
+            acc += weight.toString() + ", ";
  62
+        }
  63
+        acc += "]";
  64
+        return acc;
  65
+    }
57 66
 }
13  storm-ml/src/main/java/com/twitter/util/datautil.clj
... ...
@@ -1,13 +0,0 @@
1  
-(ns com.twitter.util.DataUtil
2  
-  (:gen-class))
3  
-
4  
-(defn parse-multiple-to-float
5  
-  ""
6  
-  [& args]
7  
-  (map #(Double/parseDouble %) args))
8  
-
9  
-(defn load-dataset
10  
-  []
11  
-  (let [data-text (map #(sutils/split % #"\t")
12  
-                       (sutils/split-lines (slurp "testSet.txt")))]
13  
-    (map #(apply parse-multiple-to-float %) data-text))

No commit comments for this range

Something went wrong with that request. Please try again.