Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

integrate drpc with local learner

  • Loading branch information...
commit ccfb1b226b9bc35464426df1f6e62d8c7795710a 1 parent 2c58f6e
@fbrubacher authored
View
18 storm-ml/src/main/java/com/twitter/MainOnlineTopology.java
@@ -2,25 +2,39 @@
import backtype.storm.Config;
import backtype.storm.LocalCluster;
+import backtype.storm.LocalDRPC;
+import backtype.storm.drpc.LinearDRPCTopologyBuilder;
import backtype.storm.topology.TopologyBuilder;
import backtype.storm.utils.Utils;
import com.twitter.algorithms.Aggregator;
+import com.twitter.storm.primitives.EvaluationBolt;
import com.twitter.storm.primitives.LocalLearner;
import com.twitter.storm.primitives.TrainingSpout;
public class MainOnlineTopology {
+ public static final String MEMCACHED_SERVERS = "127.0.0.1:11211";
+ static Double threshold = 0.5;
+ static Double bias = 1.0;
public static void main(String[] args) throws Exception {
TopologyBuilder builder = new TopologyBuilder();
+ LocalDRPC drpc = new LocalDRPC();
+
builder.setSpout("example_spitter", new TrainingSpout());
- builder.setBolt("local_learner", new LocalLearner(2), 1).shuffleGrouping("example_spitter");
+ builder.setBolt("local_learner", new LocalLearner(2, MEMCACHED_SERVERS), 1).shuffleGrouping("example_spitter");
builder.setBolt("aggregator", new Aggregator()).globalGrouping("local_learner");
+
+ LinearDRPCTopologyBuilder drpc_builder = new LinearDRPCTopologyBuilder("evaluate");
+ drpc_builder.addBolt(new EvaluationBolt(bias, threshold, MEMCACHED_SERVERS), 3);
+
Config conf = new Config();
conf.setDebug(true);
LocalCluster cluster = new LocalCluster();
- cluster.submitTopology("test", conf, builder.createTopology());
+ cluster.submitTopology("learning", conf, builder.createTopology());
+ cluster.submitTopology("evaluation", conf, drpc_builder.createLocalTopology(drpc));
+
Utils.sleep(10000);
cluster.killTopology("test");
cluster.shutdown();
View
8 storm-ml/src/main/java/com/twitter/algorithms/Aggregator.java
@@ -1,6 +1,6 @@
package com.twitter.algorithms;
-import java.util.Arrays;
+import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
@@ -16,12 +16,12 @@
public class Aggregator extends BaseRichBolt {
public static Logger LOG = Logger.getLogger(Aggregator.class);
- double[] aggregateWeights = null;
+ List<Double> aggregateWeights = null;
double totalUpdateWeight = 1.0;
public void execute(Tuple tuple) {
- double[] weight = (double[]) tuple.getValue(0);
+ List<Double> weight = (List<Double>) tuple.getValue(0);
Double parallelUpdateWeight = (Double) tuple.getValue(1);
if (parallelUpdateWeight != 1.0) {
@@ -37,7 +37,7 @@ public void execute(Tuple tuple) {
LOG.info(totalUpdateWeight);
if (aggregateWeights != null) {
MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight);
- LOG.info("New AGGREGATE vector: " + Arrays.toString(aggregateWeights));
+ // LOG.info("New AGGREGATE vector: " + Arrays.toString(aggregateWeights));
}
}
View
12 storm-ml/src/main/java/com/twitter/algorithms/Learner.java
@@ -5,10 +5,14 @@
import java.util.Arrays;
import java.util.List;
+import net.spy.memcached.CASValue;
+import net.spy.memcached.MemcachedClient;
+
import org.apache.log4j.Logger;
import com.twitter.data.Example;
import com.twitter.storm.primitives.LocalLearner;
+import com.twitter.util.Datautil;
import com.twitter.util.MathUtil;
public class Learner implements Serializable {
@@ -21,19 +25,23 @@
double totalLoss = 0.0;
double gradientSum = 0.0;
protected double learningRate = 0.0;
+ MemcachedClient memcache;
- public Learner(int dimension) {
+ public Learner(int dimension, MemcachedClient memcache) {
weights = new double[dimension];
lossFunction = new LossFunction(2);
+ this.memcache = memcache;
}
public void update(Example example, int epoch) {
+ CASValue<Object> cas_weights = (CASValue<Object>) this.memcache.get("weights");
+ List<Double> weights = Datautil.parse_str_vector((String) cas_weights.getValue());
int predicted = predict(example);
updateStats(example, predicted);
LOG.debug("EXAMPLE " + example.label + " PREDICTED: " + predicted);
if (example.isLabeled) {
if ((double) predicted != example.label) {
- double[] gradient = lossFunction.gradient(example, predicted);
+ List<Double> gradient = lossFunction.gradient(example, predicted);
gradientSum += MathUtil.l2norm(gradient);
double eta = getLearningRate(example, epoch);
MathUtil.plus(weights, MathUtil.times(gradient, -1.0 * eta));
View
10 storm-ml/src/main/java/com/twitter/algorithms/LossFunction.java
@@ -1,24 +1,26 @@
package com.twitter.algorithms;
import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
import com.twitter.data.Example;
public class LossFunction implements Serializable {
- private double[] grad; // gradient
+ private List<Double> grad; // gradient
public LossFunction(int dimension) {
- grad = new double[dimension];
+ grad = new ArrayList<Double>();
}
public double get(Example e, int prediction) {
return 0.5 * (e.label - prediction) * (e.label - prediction);
}
- public double[] gradient(Example e, int prediction) {
+ /**
+ * Computes the gradient of the squared loss for one example: each component is
+ * -(e.label - prediction) * e.x[i].
+ *
+ * @param e example whose features (e.x) and label (e.label) are read
+ * @param prediction the predicted label for e
+ * @return this object's internal gradient list, overwritten on every call
+ */
+ public List<Double> gradient(Example e, int prediction) {
double f = -1.0 * (e.label - prediction);
+ // grad is created as an empty list, so set(i, ...) would throw
+ // IndexOutOfBoundsException; rebuild the contents on each call instead.
+ grad.clear();
for (int i = 0; i < e.x.length; i++) {
- grad[i] = f * e.x[i];
+ grad.add(f * e.x[i]);
}
return grad;
}
View
59 storm-ml/src/main/java/com/twitter/storm/primitives/EvaluationBolt.java
@@ -0,0 +1,59 @@
+package com.twitter.storm.primitives;
+
+import java.util.List;
+import java.util.Map;
+
+import net.spy.memcached.AddrUtil;
+import net.spy.memcached.MemcachedClient;
+import backtype.storm.task.TopologyContext;
+import backtype.storm.topology.BasicOutputCollector;
+import backtype.storm.topology.OutputFieldsDeclarer;
+import backtype.storm.topology.base.BaseBasicBolt;
+import backtype.storm.tuple.Fields;
+import backtype.storm.tuple.Tuple;
+import backtype.storm.tuple.Values;
+
+import com.twitter.util.Datautil;
+
+/**
+ * Bolt that classifies one input vector per tuple using the weight vector stored in memcached.
+ *
+ * <p>For each tuple it reads the string stored under the memcached key {@code "weights"},
+ * parses it and the tuple's second field as vectors, computes
+ * {@code dot(input, weights) + bias}, and emits the tuple's first field together with 1 when
+ * that score is greater than the threshold, otherwise 0.
+ */
+public class EvaluationBolt extends BaseBasicBolt {
+    Double bias;
+    Double threshold;
+    String memcached_servers;
+    // Created in prepare() and released in cleanup().
+    MemcachedClient memcache;
+
+    /**
+     * @param bias constant added to the dot product before thresholding
+     * @param threshold scores strictly above this value are emitted as 1
+     * @param memcached_servers server address list handed to AddrUtil.getAddresses
+     */
+    public EvaluationBolt(Double bias, Double threshold, String memcached_servers) {
+        this.threshold = threshold;
+        this.bias = bias;
+        this.memcached_servers = memcached_servers;
+    }
+
+    /**
+     * Opens the memcached connection used by execute().
+     *
+     * @throws RuntimeException wrapping the IOException if the client cannot be created
+     */
+    @Override
+    public void prepare(Map stormConf, TopologyContext context) {
+        super.prepare(stormConf, context);
+        try {
+            this.memcache = new MemcachedClient(AddrUtil.getAddresses(this.memcached_servers));
+        } catch (java.io.IOException e) {
+            // Surface the failure with its cause instead of silently terminating the JVM.
+            throw new RuntimeException("Could not connect to memcached at " + this.memcached_servers, e);
+        }
+    }
+
+    /** Shuts down the memcached client opened in prepare(), if any. */
+    @Override
+    public void cleanup() {
+        if (this.memcache != null) {
+            this.memcache.shutdown();
+        }
+    }
+
+    /**
+     * Fetches and parses the current weight vector.
+     *
+     * @throws IllegalStateException if nothing is stored under the key "weights"
+     */
+    List<Double> get_latest_weights() {
+        Object weights = this.memcache.get("weights");
+        if (weights == null) {
+            throw new IllegalStateException("No value stored under memcached key \"weights\"");
+        }
+        return Datautil.parse_str_vector((String) weights);
+    }
+
+    /**
+     * Scores the vector carried in tuple field 1 and emits (field 0, 1 or 0).
+     * NOTE(review): field 0 is presumably the DRPC request id - confirm against the topology.
+     */
+    @Override
+    public void execute(Tuple tuple, BasicOutputCollector collector) {
+        List<Double> weights = get_latest_weights();
+
+        String input_str = tuple.getString(1);
+        List<Double> input = Datautil.parse_str_vector(input_str);
+
+        Double result = Datautil.dot_product(input, weights) + bias;
+
+        collector.emit(new Values(tuple.getValue(0), result > this.threshold ? 1 : 0));
+    }
+
+    @Override
+    public void declareOutputFields(OutputFieldsDeclarer declarer) {
+        declarer.declare(new Fields("id", "result"));
+    }
+}
View
35 storm-ml/src/main/java/com/twitter/storm/primitives/LocalLearner.java
@@ -1,10 +1,14 @@
package com.twitter.storm.primitives;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
+import net.spy.memcached.AddrUtil;
+import net.spy.memcached.MemcachedClient;
+
import org.apache.log4j.Logger;
import backtype.storm.task.OutputCollector;
@@ -30,35 +34,36 @@
HashAll hashFunction;
Learner learner;
double[] weightVector;
+ MemcachedClient memcache;
- public LocalLearner(int dimension) {
- this(dimension, new Learner(dimension));// , new HashAll());
+ public LocalLearner(int dimension, String memcached_servers) throws IOException {
+ this(dimension, new Learner(dimension, new MemcachedClient(AddrUtil.getAddresses(memcached_servers))));
}
public LocalLearner(int dimension, Learner onlinePerceptron) {// , HashAll hashAll) {
- this.dimension = dimension;
- this.learner = onlinePerceptron;
- // this.hashFunction = hashAll;
- weightVector = new double[dimension];
- weightVector = new double[dimension];
- weightVector[0] = -6.8;
- weightVector[1] = -0.8;
- learner.setWeights(weightVector);
+ try {
+ this.dimension = dimension;
+ this.learner = onlinePerceptron;
+ // this.hashFunction = hashAll;
+
+ weightVector = new double[dimension];
+ weightVector = new double[dimension];
+ weightVector[0] = -6.8;
+ weightVector[1] = -0.8;
+ learner.setWeights(weightVector);
+ } catch (Exception e) {
+
+ }
}
public void execute(Tuple tuple) {
- LOG.debug("Old weights" + Arrays.toString(learner.getWeights()));
Example example = new Example(2);
example.x[0] = (Double) tuple.getValue(0);
example.x[1] = (Double) tuple.getValue(1);
example.label = (Double) tuple.getValue(2);
- example.isLabeled = true;
learner.update(example, 1);
_collector.emit(Arrays.asList((Object) learner.getWeights(), (Object) learner.getParallelUpdateWeight()));
_collector.ack(tuple);
- LOG.debug("New weights" + Arrays.toString(learner.getWeights()));
- // example.parseFrom((String) tuple.getValue(1), hashFunction);
- // buffer.add(example);
}
public void declareOutputFields(OutputFieldsDeclarer declarer) {
View
22 storm-ml/src/main/java/com/twitter/util/Datautil.java
@@ -6,9 +6,31 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import java.util.Scanner;
public class Datautil {
+    /**
+     * Parses a vector rendered as text, e.g. {@code "[1.0, 2.5]"}, into a list of doubles.
+     *
+     * <p>The first and last characters are dropped (the surrounding brackets) and the rest is
+     * split on {@code ", "}. Parsing stops at the first token that is not a double.
+     *
+     * @param str_vector the bracketed, comma-and-space separated vector
+     * @return the parsed values in order; empty for {@code "[]"}
+     * @throws IllegalArgumentException if str_vector is null or shorter than two characters
+     */
+    public static List<Double> parse_str_vector(String str_vector) {
+        if (str_vector == null || str_vector.length() < 2) {
+            throw new IllegalArgumentException(
+                    "Expected a bracketed vector such as \"[1.0, 2.0]\" but got: " + str_vector);
+        }
+        List<Double> vector = new ArrayList<Double>();
+
+        Scanner scanner = new Scanner(str_vector.substring(1, str_vector.length() - 1));
+        try {
+            scanner.useDelimiter(", ");
+            // Scanner parses numbers with the JVM default locale unless told otherwise; pin it
+            // so "1.5" is read as a double even where the decimal separator is a comma.
+            scanner.useLocale(java.util.Locale.ROOT);
+
+            while (scanner.hasNextDouble()) {
+                vector.add(scanner.nextDouble());
+            }
+        } finally {
+            scanner.close();
+        }
+
+        return vector;
+    }
+
+    /**
+     * Returns the dot product of two vectors.
+     *
+     * <p>Walks vector_a and pairs each element with the element of vector_b at the same
+     * index, so vector_b must hold at least as many elements as vector_a.
+     */
+    public static Double dot_product(List<Double> vector_a, List<Double> vector_b) {
+        double sum = 0.0;
+        int index = 0;
+
+        for (Double component : vector_a) {
+            sum += component * vector_b.get(index);
+            index++;
+        }
+
+        return sum;
+    }
+
public List<Double[]> readTrainingFile() {
List<Double[]> lines = new ArrayList<Double[]>();
String strLine;
View
55 storm-ml/src/main/java/com/twitter/util/MathUtil.java
@@ -1,70 +1,69 @@
package com.twitter.util;
import java.math.BigInteger;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
-/**
- * Misc. math util functions
- * (refactor with Twitter specific ones)
- * @author Delip Rao
- */
public class MathUtil {
- public static double l2norm(double [] v) {
+ public static double l2norm(List<Double> gradient) {
double sum = 0;
- for (double d : v) {
- sum += d*d;
+ for (double d : gradient) {
+ sum += d * d;
}
return sum;
}
- public static double [] zero(double [] v) {
+ public static double[] zero(double[] v) {
for (int i = 0; i < v.length; i++) {
v[i] = 0;
}
return v;
}
- public static double [] times(double [] v, double factor) {
- for (int i = 0; i < v.length; i++) {
- v[i] *= factor;
+ /**
+ * Multiplies every element of weights by factor, in place.
+ *
+ * @param weights list to scale; it is modified and must support set()
+ * @param factor multiplier applied to each element
+ * @return the same list instance that was passed in
+ */
+ public static List<Double> times(List<Double> weights, double factor) {
+ for (int i = 0; i < weights.size(); i++) {
+ // A for-each over boxed Doubles only reassigns the loop variable, leaving the
+ // list untouched; write the scaled value back explicitly.
+ weights.set(i, weights.get(i) * factor);
}
- return v;
+ return weights;
}
- public static double [] timesC(double [] v, double factor) {
- double [] vc = Arrays.copyOf(v, v.length);
+ public static double[] timesC(double[] v, double factor) {
+ double[] vc = Arrays.copyOf(v, v.length);
for (int i = 0; i < v.length; i++) {
vc[i] *= factor;
}
return vc;
}
- public static double [] plus(double [] v, double [] u) {
- for (int i = 0; i < v.length; i++) {
- v[i] += u[i];
+ public static List<Double> plus(List<Double> weights, double[] u) {
+ for (int i = 0; i < u.length; i++) {
+ Double weight = weights.get(i);
+ weight += u[i];
+ weights.set(i, weight);
}
- return v;
+ return weights;
}
- public static double [] minus(double [] v, double [] u) {
+ public static double[] minus(double[] v, double[] u) {
for (int i = 0; i < v.length; i++) {
v[i] -= u[i];
}
return v;
}
- public static double [] minusC(double [] v, double [] u) {
- double [] vc = Arrays.copyOf(v, v.length);
+ public static double[] minusC(double[] v, double[] u) {
+ double[] vc = Arrays.copyOf(v, v.length);
for (int i = 0; i < v.length; i++) {
vc[i] -= u[i];
}
return vc;
}
- public static double dot(double [] u, double [] v) {
+ public static double dot(double[] u, double[] v) {
double result = 0;
for (int i = 0; i < v.length; i++) {
- result += u[i]*v[i];
+ result += u[i] * v[i];
}
return result;
}
@@ -73,4 +72,12 @@ public static int nextLikelyPrime(int n) {
String s = String.valueOf(n - 1);
return new BigInteger(s).nextProbablePrime().intValue();
}
+
+    /**
+     * Returns the element-wise sum of two vectors as a new list; neither argument is modified.
+     *
+     * <p>NOTE(review): the plus(List, double[]) overload in this class updates its first
+     * argument in place, while this overload does not - confirm callers use the return value.
+     *
+     * @param aggregateWeights first operand; its size determines the size of the result
+     * @param weight second operand; must hold at least as many elements as aggregateWeights
+     * @return a new list where element i is aggregateWeights[i] + weight[i]
+     */
+    public static List<Double> plus(List<Double> aggregateWeights, List<Double> weight) {
+        List<Double> result = new ArrayList<Double>(aggregateWeights.size());
+        for (int i = 0; i < aggregateWeights.size(); i++) {
+            // Addition, as the name says; the elements were previously multiplied here.
+            result.add(aggregateWeights.get(i) + weight.get(i));
+        }
+        return result;
+    }
}
Please sign in to comment.
Something went wrong with that request. Please try again.