Skip to content

Commit

Permalink
Added new project structure
Browse files Browse the repository at this point in the history
  • Loading branch information
fbrubacher committed Jul 3, 2012
1 parent b2c9e41 commit 5b2ac55
Show file tree
Hide file tree
Showing 23 changed files with 762 additions and 0 deletions.
Binary file added .DS_Store
Binary file not shown.
10 changes: 10 additions & 0 deletions storm-ml/.classpath
@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
<classpathentry kind="src" output="target/classes" path="src/main/java"/>
<classpathentry excluding="**" kind="src" output="target/classes" path="src/main/clojure"/>
<classpathentry kind="src" output="target/test-classes" path="src/test/java"/>
<classpathentry excluding="**" kind="src" output="target/test-classes" path="src/test/resources"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/J2SE-1.5"/>
<classpathentry kind="con" path="org.eclipse.m2e.MAVEN2_CLASSPATH_CONTAINER"/>
<classpathentry kind="output" path="target/classes"/>
</classpath>
36 changes: 36 additions & 0 deletions storm-ml/.project
@@ -0,0 +1,36 @@
<?xml version="1.0" encoding="UTF-8"?>
<projectDescription>
<name>storm-ml2</name>
<comment></comment>
<projects>
</projects>
<buildSpec>
<buildCommand>
<name>ccw.builder</name>
<arguments>
</arguments>
</buildCommand>
<buildCommand>
<name>org.eclipse.jdt.core.javabuilder</name>
<arguments>
</arguments>
</buildCommand>
<buildCommand>
<name>org.eclipse.m2e.core.maven2Builder</name>
<arguments>
</arguments>
</buildCommand>
</buildSpec>
<natures>
<nature>org.eclipse.jdt.core.javanature</nature>
<nature>org.eclipse.m2e.core.maven2Nature</nature>
<nature>ccw.nature</nature>
</natures>
<linkedResources>
<link>
<name>clojure</name>
<type>2</type>
<location>/Users/fbrubacher/Documents/workspace/storm-ml2/src/main/clojure</location>
</link>
</linkedResources>
</projectDescription>
2 changes: 2 additions & 0 deletions storm-ml/.settings/org.eclipse.core.resources.prefs
@@ -0,0 +1,2 @@
eclipse.preferences.version=1
encoding/<project>=UTF-8
5 changes: 5 additions & 0 deletions storm-ml/.settings/org.eclipse.jdt.core.prefs
@@ -0,0 +1,5 @@
eclipse.preferences.version=1
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.5
org.eclipse.jdt.core.compiler.compliance=1.5
org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning
org.eclipse.jdt.core.compiler.source=1.5
4 changes: 4 additions & 0 deletions storm-ml/.settings/org.eclipse.m2e.core.prefs
@@ -0,0 +1,4 @@
activeProfiles=
eclipse.preferences.version=1
resolveWorkspaceProjects=true
version=1
4 changes: 4 additions & 0 deletions storm-ml/MainStorm.java
@@ -0,0 +1,4 @@

public class MainStorm {

}
4 changes: 4 additions & 0 deletions storm-ml/PerceptronTopology.java
@@ -0,0 +1,4 @@

public class PerceptronTopology {

}
16 changes: 16 additions & 0 deletions storm-ml/src/main/clojure/com/twitter/util/datautil.clj
@@ -0,0 +1,16 @@
(ns com.twitter.Datautil
(:require [clojure.string :as sstring])
(:gen-class))

(defn parse-multiple-to-double
""
[& args]
(map #(Double/parseDouble %) args))

(def load-dataset
(let [data-text (map #(sstring/split % #"\t")
(sstring/split-lines (slurp "testSet.txt")))]
(map #(apply parse-multiple-to-double %) data-text)))

(def array-dataset
(into-array (map (partial into-array Double/TYPE) load-dataset)))
13 changes: 13 additions & 0 deletions storm-ml/src/main/java/com/twitter/Main.java
@@ -0,0 +1,13 @@
package com.twitter;

import java.io.IOException;

import com.twitter.util.MathUtil;

public class Main {

public static void main(String[] args) throws IOException {
int dimension = MathUtil.nextLikelyPrime(10000);
// Learner learner = new OnlinePerceptron(dimension);
}
}
52 changes: 52 additions & 0 deletions storm-ml/src/main/java/com/twitter/MainOnlineTopology.java
@@ -0,0 +1,52 @@
package com.twitter;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

import backtype.storm.Config;
import backtype.storm.LocalCluster;
import backtype.storm.topology.TopologyBuilder;
import backtype.storm.tuple.Values;
import backtype.storm.utils.Utils;

import com.twitter.storm.primitives.LocalLearner;
import com.twitter.storm.primitives.TrainingSpout;
import com.twitter.util.MathUtil;

public class MainOnlineTopology {

public static List<List<Object>> readExamples(String fileName) throws IOException {
Scanner in = new Scanner(new File(fileName));
List<List<Object>> tupleList = new ArrayList<List<Object>>();
while (in.hasNext()) {
String line = in.nextLine();
tupleList.add(new Values(line));
}
in.close();
return tupleList;
}

public static void main(String[] args) throws Exception {
int dimension = MathUtil.nextLikelyPrime(10);
System.out.println("Using dimension: " + dimension);

// Map exampleMap = new HashMap<Integer, List<List<Object>>>();
// exampleMap.put(0, readExamples(args[0]));

TopologyBuilder builder = new TopologyBuilder();
builder.setSpout("example_spitter", new TrainingSpout());
builder.setBolt("local_learner", new LocalLearner(2), 1).shuffleGrouping("example_spitter");
Config conf = new Config();
conf.setDebug(true);
LocalCluster cluster = new LocalCluster();
cluster.submitTopology("test", conf, builder.createTopology());
Utils.sleep(10000);
cluster.killTopology("test");
cluster.shutdown();

// builder.setBolt("local_learner", new LocalLearner(dimension), 1).customGrouping(spout, grouping);
}
}
61 changes: 61 additions & 0 deletions storm-ml/src/main/java/com/twitter/algorithms/Aggregator.java
@@ -0,0 +1,61 @@
package com.twitter.algorithms;

import java.util.Arrays;
import java.util.Map;

import org.apache.log4j.Logger;

import backtype.storm.coordination.BatchOutputCollector;
import backtype.storm.task.OutputCollector;
import backtype.storm.task.TopologyContext;
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.topology.base.BaseRichBolt;
import backtype.storm.transactional.ICommitter;
import backtype.storm.tuple.Tuple;

import com.twitter.util.MathUtil;

public class Aggregator extends BaseRichBolt implements ICommitter {

public static Logger LOG = Logger.getLogger(Aggregator.class);
double[] aggregateWeights = null;
double totalUpdateWeight = 0.0;

public void prepare(Map conf, TopologyContext context, BatchOutputCollector collector, Object id) {
// TODO Auto-generated method stub

}

public void execute(Tuple tuple) {

double[] weight = (double[]) tuple.getValue(1);
double parallelUpdateWeight = (Double) tuple.getValue(2);
if (parallelUpdateWeight != 1.0) {
weight = MathUtil.times(weight, parallelUpdateWeight);
}
if (aggregateWeights == null) {
aggregateWeights = weight;
} else {
MathUtil.plus(aggregateWeights, weight);
}
totalUpdateWeight += parallelUpdateWeight;
}

public void finishBatch() {
if (aggregateWeights != null) {
MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight);
LOG.info("New weight vector: " + Arrays.toString(aggregateWeights));
}
}

public void declareOutputFields(OutputFieldsDeclarer declarer) {
// TODO Auto-generated method stub

}

public void prepare(Map stormConf, TopologyContext context, OutputCollector collector) {
// TODO Auto-generated method stub

}

}
87 changes: 87 additions & 0 deletions storm-ml/src/main/java/com/twitter/algorithms/Learner.java
@@ -0,0 +1,87 @@
package com.twitter.algorithms;

import java.io.Serializable;
import java.util.Arrays;

import org.apache.log4j.Logger;

import com.twitter.data.Example;
import com.twitter.storm.primitives.LocalLearner;
import com.twitter.util.MathUtil;

public class Learner implements Serializable {
public static Logger LOG = Logger.getLogger(LocalLearner.class);

protected double[] weights;
protected LossFunction lossFunction;
int numExamples = 0;
int numMisclassified = 0;
double totalLoss = 0.0;
double gradientSum = 0.0;
protected double learningRate = 1.0;

public Learner(int dimension) {
weights = new double[dimension];
lossFunction = new LossFunction(2);
}

public void update(Example example, int epoch) {
int predicted = predict(example);
updateStats(example, predicted);
LOG.debug("EXAMPLE " + example.label + " PREDICTED: " + predicted);
if (example.isLabeled) {
if ((double) predicted != example.label) {
double[] gradient = lossFunction.gradient(example, predicted);
gradientSum += MathUtil.l2norm(gradient);
double eta = getLearningRate(example, epoch);
MathUtil.plus(weights, MathUtil.times(gradient, -1.0 * eta));
}
}
displayStats();
}

protected double getLearningRate(Example example, int timestamp) {
return learningRate / Math.sqrt(timestamp);
}

public double[] getWeights() {
return weights;
}

public double getParallelUpdateWeight() {
return gradientSum;
}

public void initWeights(double[] newWeights) {
assert (newWeights.length == weights.length);
weights = Arrays.copyOf(newWeights, newWeights.length);
}

public int predict(Example example) {
double dot = MathUtil.dot(weights, example.x);
return (dot >= 0.0) ? 1 : -1;
}

protected void updateStats(Example example, int prediction) {
numExamples++;
if (example.label != prediction)
numMisclassified++;
totalLoss += lossFunction.get(example, prediction);
}

public void displayStats() {
if (numExamples == 0) {
System.out.println("No examples seen so far.");
}
double accuracy = 1.0 - numMisclassified * 1.0 / numExamples;
double meanLoss = totalLoss * 1.0 / numExamples;
LOG.info(String.format("Accuracy: %g\tMean Loss: %g", accuracy, meanLoss));

}

public void resetStats() {
numExamples = 0;
numMisclassified = 0;
totalLoss = 0.0;
}
}
30 changes: 30 additions & 0 deletions storm-ml/src/main/java/com/twitter/algorithms/LossFunction.java
@@ -0,0 +1,30 @@
package com.twitter.algorithms;

import java.io.Serializable;

import com.twitter.data.Example;

public class LossFunction implements Serializable {
private double[] grad; // gradient

public LossFunction(int dimension) {
grad = new double[dimension];
}

public double get(Example e, int prediction) {
return 0.5 * (e.label - prediction) * (e.label - prediction);
}

public double[] gradient(Example e, int prediction) {
double f = -1.0 * (e.label - prediction);
for (int i = 0; i < e.x.length; i++) {
grad[i] = f * e.x[i];
}
return grad;
}

static LossFunction byName(String name, int dimension) {
return new LossFunction(dimension);
}

}

0 comments on commit 5b2ac55

Please sign in to comment.