forked from nathanmarz/storm-contrib
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b2c9e41
commit 5b2ac55
Showing
23 changed files
with
762 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,10 @@ | |||
<?xml version="1.0" encoding="UTF-8"?> | |||
<classpath> | |||
<classpathentry kind="src" output="target/classes" path="src/main/java"/> | |||
<classpathentry excluding="**" kind="src" output="target/classes" path="src/main/clojure"/> | |||
<classpathentry kind="src" output="target/test-classes" path="src/test/java"/> | |||
<classpathentry excluding="**" kind="src" output="target/test-classes" path="src/test/resources"/> | |||
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/J2SE-1.5"/> | |||
<classpathentry kind="con" path="org.eclipse.m2e.MAVEN2_CLASSPATH_CONTAINER"/> | |||
<classpathentry kind="output" path="target/classes"/> | |||
</classpath> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,36 @@ | |||
<?xml version="1.0" encoding="UTF-8"?> | |||
<projectDescription> | |||
<name>storm-ml2</name> | |||
<comment></comment> | |||
<projects> | |||
</projects> | |||
<buildSpec> | |||
<buildCommand> | |||
<name>ccw.builder</name> | |||
<arguments> | |||
</arguments> | |||
</buildCommand> | |||
<buildCommand> | |||
<name>org.eclipse.jdt.core.javabuilder</name> | |||
<arguments> | |||
</arguments> | |||
</buildCommand> | |||
<buildCommand> | |||
<name>org.eclipse.m2e.core.maven2Builder</name> | |||
<arguments> | |||
</arguments> | |||
</buildCommand> | |||
</buildSpec> | |||
<natures> | |||
<nature>org.eclipse.jdt.core.javanature</nature> | |||
<nature>org.eclipse.m2e.core.maven2Nature</nature> | |||
<nature>ccw.nature</nature> | |||
</natures> | |||
<linkedResources> | |||
<link> | |||
<name>clojure</name> | |||
<type>2</type> | |||
<location>/Users/fbrubacher/Documents/workspace/storm-ml2/src/main/clojure</location> | |||
</link> | |||
</linkedResources> | |||
</projectDescription> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,2 @@ | |||
eclipse.preferences.version=1 | |||
encoding/<project>=UTF-8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,5 @@ | |||
eclipse.preferences.version=1 | |||
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.5 | |||
org.eclipse.jdt.core.compiler.compliance=1.5 | |||
org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning | |||
org.eclipse.jdt.core.compiler.source=1.5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,4 @@ | |||
activeProfiles= | |||
eclipse.preferences.version=1 | |||
resolveWorkspaceProjects=true | |||
version=1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,4 @@ | |||
|
|||
public class MainStorm { | |||
|
|||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,4 @@ | |||
|
|||
public class PerceptronTopology { | |||
|
|||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,16 @@ | |||
(ns com.twitter.Datautil | |||
(:require [clojure.string :as sstring]) | |||
(:gen-class)) | |||
|
|||
(defn parse-multiple-to-double | |||
"" | |||
[& args] | |||
(map #(Double/parseDouble %) args)) | |||
|
|||
(def load-dataset | |||
(let [data-text (map #(sstring/split % #"\t") | |||
(sstring/split-lines (slurp "testSet.txt")))] | |||
(map #(apply parse-multiple-to-double %) data-text))) | |||
|
|||
(def array-dataset | |||
(into-array (map (partial into-array Double/TYPE) load-dataset))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,13 @@ | |||
package com.twitter; | |||
|
|||
import java.io.IOException; | |||
|
|||
import com.twitter.util.MathUtil; | |||
|
|||
public class Main { | |||
|
|||
public static void main(String[] args) throws IOException { | |||
int dimension = MathUtil.nextLikelyPrime(10000); | |||
// Learner learner = new OnlinePerceptron(dimension); | |||
} | |||
} |
52 changes: 52 additions & 0 deletions
52
storm-ml/src/main/java/com/twitter/MainOnlineTopology.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,52 @@ | |||
package com.twitter; | |||
|
|||
import java.io.File; | |||
import java.io.IOException; | |||
import java.util.ArrayList; | |||
import java.util.List; | |||
import java.util.Scanner; | |||
|
|||
import backtype.storm.Config; | |||
import backtype.storm.LocalCluster; | |||
import backtype.storm.topology.TopologyBuilder; | |||
import backtype.storm.tuple.Values; | |||
import backtype.storm.utils.Utils; | |||
|
|||
import com.twitter.storm.primitives.LocalLearner; | |||
import com.twitter.storm.primitives.TrainingSpout; | |||
import com.twitter.util.MathUtil; | |||
|
|||
public class MainOnlineTopology { | |||
|
|||
public static List<List<Object>> readExamples(String fileName) throws IOException { | |||
Scanner in = new Scanner(new File(fileName)); | |||
List<List<Object>> tupleList = new ArrayList<List<Object>>(); | |||
while (in.hasNext()) { | |||
String line = in.nextLine(); | |||
tupleList.add(new Values(line)); | |||
} | |||
in.close(); | |||
return tupleList; | |||
} | |||
|
|||
public static void main(String[] args) throws Exception { | |||
int dimension = MathUtil.nextLikelyPrime(10); | |||
System.out.println("Using dimension: " + dimension); | |||
|
|||
// Map exampleMap = new HashMap<Integer, List<List<Object>>>(); | |||
// exampleMap.put(0, readExamples(args[0])); | |||
|
|||
TopologyBuilder builder = new TopologyBuilder(); | |||
builder.setSpout("example_spitter", new TrainingSpout()); | |||
builder.setBolt("local_learner", new LocalLearner(2), 1).shuffleGrouping("example_spitter"); | |||
Config conf = new Config(); | |||
conf.setDebug(true); | |||
LocalCluster cluster = new LocalCluster(); | |||
cluster.submitTopology("test", conf, builder.createTopology()); | |||
Utils.sleep(10000); | |||
cluster.killTopology("test"); | |||
cluster.shutdown(); | |||
|
|||
// builder.setBolt("local_learner", new LocalLearner(dimension), 1).customGrouping(spout, grouping); | |||
} | |||
} |
61 changes: 61 additions & 0 deletions
61
storm-ml/src/main/java/com/twitter/algorithms/Aggregator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,61 @@ | |||
package com.twitter.algorithms; | |||
|
|||
import java.util.Arrays; | |||
import java.util.Map; | |||
|
|||
import org.apache.log4j.Logger; | |||
|
|||
import backtype.storm.coordination.BatchOutputCollector; | |||
import backtype.storm.task.OutputCollector; | |||
import backtype.storm.task.TopologyContext; | |||
import backtype.storm.topology.OutputFieldsDeclarer; | |||
import backtype.storm.topology.base.BaseRichBolt; | |||
import backtype.storm.transactional.ICommitter; | |||
import backtype.storm.tuple.Tuple; | |||
|
|||
import com.twitter.util.MathUtil; | |||
|
|||
public class Aggregator extends BaseRichBolt implements ICommitter { | |||
|
|||
public static Logger LOG = Logger.getLogger(Aggregator.class); | |||
double[] aggregateWeights = null; | |||
double totalUpdateWeight = 0.0; | |||
|
|||
public void prepare(Map conf, TopologyContext context, BatchOutputCollector collector, Object id) { | |||
// TODO Auto-generated method stub | |||
|
|||
} | |||
|
|||
public void execute(Tuple tuple) { | |||
|
|||
double[] weight = (double[]) tuple.getValue(1); | |||
double parallelUpdateWeight = (Double) tuple.getValue(2); | |||
if (parallelUpdateWeight != 1.0) { | |||
weight = MathUtil.times(weight, parallelUpdateWeight); | |||
} | |||
if (aggregateWeights == null) { | |||
aggregateWeights = weight; | |||
} else { | |||
MathUtil.plus(aggregateWeights, weight); | |||
} | |||
totalUpdateWeight += parallelUpdateWeight; | |||
} | |||
|
|||
public void finishBatch() { | |||
if (aggregateWeights != null) { | |||
MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight); | |||
LOG.info("New weight vector: " + Arrays.toString(aggregateWeights)); | |||
} | |||
} | |||
|
|||
public void declareOutputFields(OutputFieldsDeclarer declarer) { | |||
// TODO Auto-generated method stub | |||
|
|||
} | |||
|
|||
public void prepare(Map stormConf, TopologyContext context, OutputCollector collector) { | |||
// TODO Auto-generated method stub | |||
|
|||
} | |||
|
|||
} |
87 changes: 87 additions & 0 deletions
87
storm-ml/src/main/java/com/twitter/algorithms/Learner.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,87 @@ | |||
package com.twitter.algorithms; | |||
|
|||
import java.io.Serializable; | |||
import java.util.Arrays; | |||
|
|||
import org.apache.log4j.Logger; | |||
|
|||
import com.twitter.data.Example; | |||
import com.twitter.storm.primitives.LocalLearner; | |||
import com.twitter.util.MathUtil; | |||
|
|||
public class Learner implements Serializable { | |||
public static Logger LOG = Logger.getLogger(LocalLearner.class); | |||
|
|||
protected double[] weights; | |||
protected LossFunction lossFunction; | |||
int numExamples = 0; | |||
int numMisclassified = 0; | |||
double totalLoss = 0.0; | |||
double gradientSum = 0.0; | |||
protected double learningRate = 1.0; | |||
|
|||
public Learner(int dimension) { | |||
weights = new double[dimension]; | |||
lossFunction = new LossFunction(2); | |||
} | |||
|
|||
public void update(Example example, int epoch) { | |||
int predicted = predict(example); | |||
updateStats(example, predicted); | |||
LOG.debug("EXAMPLE " + example.label + " PREDICTED: " + predicted); | |||
if (example.isLabeled) { | |||
if ((double) predicted != example.label) { | |||
double[] gradient = lossFunction.gradient(example, predicted); | |||
gradientSum += MathUtil.l2norm(gradient); | |||
double eta = getLearningRate(example, epoch); | |||
MathUtil.plus(weights, MathUtil.times(gradient, -1.0 * eta)); | |||
} | |||
} | |||
displayStats(); | |||
} | |||
|
|||
protected double getLearningRate(Example example, int timestamp) { | |||
return learningRate / Math.sqrt(timestamp); | |||
} | |||
|
|||
public double[] getWeights() { | |||
return weights; | |||
} | |||
|
|||
public double getParallelUpdateWeight() { | |||
return gradientSum; | |||
} | |||
|
|||
public void initWeights(double[] newWeights) { | |||
assert (newWeights.length == weights.length); | |||
weights = Arrays.copyOf(newWeights, newWeights.length); | |||
} | |||
|
|||
public int predict(Example example) { | |||
double dot = MathUtil.dot(weights, example.x); | |||
return (dot >= 0.0) ? 1 : -1; | |||
} | |||
|
|||
protected void updateStats(Example example, int prediction) { | |||
numExamples++; | |||
if (example.label != prediction) | |||
numMisclassified++; | |||
totalLoss += lossFunction.get(example, prediction); | |||
} | |||
|
|||
public void displayStats() { | |||
if (numExamples == 0) { | |||
System.out.println("No examples seen so far."); | |||
} | |||
double accuracy = 1.0 - numMisclassified * 1.0 / numExamples; | |||
double meanLoss = totalLoss * 1.0 / numExamples; | |||
LOG.info(String.format("Accuracy: %g\tMean Loss: %g", accuracy, meanLoss)); | |||
|
|||
} | |||
|
|||
public void resetStats() { | |||
numExamples = 0; | |||
numMisclassified = 0; | |||
totalLoss = 0.0; | |||
} | |||
} |
30 changes: 30 additions & 0 deletions
30
storm-ml/src/main/java/com/twitter/algorithms/LossFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,30 @@ | |||
package com.twitter.algorithms; | |||
|
|||
import java.io.Serializable; | |||
|
|||
import com.twitter.data.Example; | |||
|
|||
public class LossFunction implements Serializable { | |||
private double[] grad; // gradient | |||
|
|||
public LossFunction(int dimension) { | |||
grad = new double[dimension]; | |||
} | |||
|
|||
public double get(Example e, int prediction) { | |||
return 0.5 * (e.label - prediction) * (e.label - prediction); | |||
} | |||
|
|||
public double[] gradient(Example e, int prediction) { | |||
double f = -1.0 * (e.label - prediction); | |||
for (int i = 0; i < e.x.length; i++) { | |||
grad[i] = f * e.x[i]; | |||
} | |||
return grad; | |||
} | |||
|
|||
static LossFunction byName(String name, int dimension) { | |||
return new LossFunction(dimension); | |||
} | |||
|
|||
} |
Oops, something went wrong.