Permalink
Browse files

Merge branch 'master' of http://github.com/fbrubacher/storm-contrib i…

…nto storm-ml
  • Loading branch information...
2 parents 826fbb3 + 63681e2 commit 34a18b42523885dea04593cb4c0ccd233c0aa736 @fbrubacher fbrubacher committed Aug 8, 2012
@@ -11,13 +11,13 @@
import org.apache.log4j.Logger;
import com.twitter.data.Example;
-import com.twitter.storm.primitives.example.LocalLearner;
import com.twitter.util.Datautil;
import com.twitter.util.MathUtil;
public class Learner implements Serializable {
- public static Logger LOG = Logger.getLogger(LocalLearner.class);
+ public static Logger LOG = Logger.getLogger(Learner.class);
+ private double threshold = 0.0;
protected double[] weights;
protected LossFunction lossFunction;
int numExamples = 0;
@@ -83,7 +83,7 @@ public void initWeights(double[] newWeights) {
public int predict(Example example) {
double dot = MathUtil.dot(weights, example.x);
- return (dot >= 0.0) ? 1 : -1;
+ return (dot >= this.threshold) ? 1 : -1;
}
protected void updateStats(Example example, int prediction) {
@@ -1,8 +1,14 @@
package com.twitter.storm.example;
+import java.util.ArrayList;
+import java.util.List;
+
import net.spy.memcached.AddrUtil;
import net.spy.memcached.MemcachedClient;
import net.spy.memcached.internal.OperationFuture;
+
+import org.apache.log4j.Logger;
+
import backtype.storm.Config;
import backtype.storm.LocalCluster;
import backtype.storm.LocalDRPC;
@@ -15,6 +21,7 @@
public class MainOnlineTopology {
public static final String MEMCACHED_SERVERS = "127.0.0.1:11211";
+ public static Logger LOG = Logger.getLogger(MainOnlineTopology.class);
static Double threshold = 0.5;
static Double bias = 1.0;
@@ -34,7 +41,7 @@ public static void main(String[] args) throws Exception {
ml_topology_builder.setTrainingSpout(new ExampleTrainingSpout());
ml_topology_builder.setTrainingBolt(new LocalLearner(2, MEMCACHED_SERVERS));
- ml_topology_builder.setEvaluationBolt(new EvaluationBolt(1.0, 2.0, MEMCACHED_SERVERS));
+ ml_topology_builder.setEvaluationBolt(new EvaluationBolt(1.0, 0.0, MEMCACHED_SERVERS));
if (args == null || args.length == 0) {
LocalDRPC drpc = new LocalDRPC();
@@ -43,6 +50,12 @@ public static void main(String[] args) throws Exception {
cluster.submitTopology(topology_name, topology_conf,
ml_topology_builder.createLocalTopology("evaluate", drpc));
+ List<Double> testVector = new ArrayList<Double>();
+ testVector.add(3.0);
+ testVector.add(1.0);
+ String result = drpc.execute("evaluate", testVector.toString());
+ LOG.error("RESULT: " + result);
+
Utils.sleep(10000);
cluster.killTopology("perceptron");
cluster.shutdown();
@@ -35,7 +35,6 @@ public void execute(Tuple tuple) {
List<Double> weight = (List<Double>) tuple.getValue(0);
Double parallelUpdateWeight = (Double) tuple.getValue(1);
- LOG.error("AGGGG" + weight);
if (parallelUpdateWeight != 1.0) {
weight = MathUtil.times(weight, parallelUpdateWeight);
}
@@ -46,7 +45,6 @@ public void execute(Tuple tuple) {
}
totalUpdateWeight += parallelUpdateWeight;
MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight);
- LOG.info("aggregate weights" + aggregateWeights);
if (aggregateWeights != null) {
memcache.set("model", 3600 * 24, Datautil.toStrVector(aggregateWeights));
}
@@ -38,19 +38,20 @@ public void prepare(Map stormConf, TopologyContext context) {
}
List<Double> get_latest_weights() {
- String weights = (String) this.memcache.get("weights");
+ String weights = (String) this.memcache.get("model");
return Datautil.parse_str_vector(weights);
}
public void execute(Tuple tuple, BasicOutputCollector collector) {
- List<Double> weights = get_latest_weights();
-
String input_str = tuple.getString(1);
+
+ List<Double> weights = get_latest_weights();
List<Double> input = Datautil.parse_str_vector(input_str);
- Double result = Datautil.dot_product(input, weights) + bias;
+ Double evaluation = Datautil.dot_product(input, weights) + this.bias;
+ String result = evaluation > this.threshold ? "1" : "-1";
- collector.emit(new Values(tuple.getValue(0), result > this.threshold ? 1 : 0));
+ collector.emit(new Values(tuple.getString(0), result));
}
public void declareOutputFields(OutputFieldsDeclarer declarer) {

0 comments on commit 34a18b4

Please sign in to comment.