Skip to content

Commit

Permalink
WIP perceptron train method
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Cetrulo authored and git2samus committed Jun 6, 2012
1 parent 1c907b2 commit 656ea7f
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
1 change: 1 addition & 0 deletions storm-ml/project.clj
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
:javac-options {:debug "true" :fork "true"}
:aot :all
:dev-dependencies [[org.clojure/clojure "1.4.0"]
[org.javatuples/javatuples "1.2"]
[storm "0.7.2"]])
40 changes: 40 additions & 0 deletions storm-ml/src/jvm/storm/ml/Main.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package storm.ml;

import java.lang.Boolean;
import java.math.BigDecimal;
import java.util.List;
import java.util.ArrayList;

import org.javatuples.Pair;

import storm.ml.PerceptronTopologyBuilder;

public class Main {
public static void main(String[] keywords) {
List<Pair<List<BigDecimal>, Boolean>> training_set = new ArrayList<Pair<List<BigDecimal>, Boolean>>(4);
List<BigDecimal> input_vector = new ArrayList<BigDecimal>(3);

input_vector.add(new BigDecimal(1));
input_vector.add(new BigDecimal(0));
input_vector.add(new BigDecimal(0));
training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, true));

input_vector.add(new BigDecimal(1));
input_vector.add(new BigDecimal(0));
input_vector.add(new BigDecimal(1));
training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, true));

input_vector.add(new BigDecimal(1));
input_vector.add(new BigDecimal(1));
input_vector.add(new BigDecimal(0));
training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, true));

input_vector.add(new BigDecimal(1));
input_vector.add(new BigDecimal(1));
input_vector.add(new BigDecimal(1));
training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, false));

PerceptronTopologyBuilder ptb = new PerceptronTopologyBuilder(3, 0.5, 0.1);
ptb.train(training_set);
}
}
59 changes: 59 additions & 0 deletions storm-ml/src/jvm/storm/ml/PerceptronTopologyBuilder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package storm.ml;

import java.lang.Integer;
import java.lang.Double;
import java.lang.Boolean;
import java.math.BigDecimal;
import java.util.List;
import java.util.ArrayList;

import org.javatuples.Pair;

public class PerceptronTopologyBuilder {
public final Integer size;
public final Double threshold;
public final Double learning_rate;

private List<BigDecimal> weights;

public PerceptronTopologyBuilder(Integer size, Double threshold, Double learning_rate) {
this.size = size; // size of the weight array and input
this.threshold = threshold; // margin to determine positive results
this.learning_rate = learning_rate; // adaptation factor for the weights

this.weights = new ArrayList<BigDecimal>(size);
int i; for (i=0; i<size; i++)
this.weights.add(new BigDecimal(0));
}

private BigDecimal dot_product(List<BigDecimal> vector1, List<BigDecimal> vector2) {
BigDecimal result = new BigDecimal(0);
int i; for (i=0; i<this.size; i++)
result.add(vector1.get(i).multiply(vector2.get(i)));

return result;
}

public void train(List<Pair<List<BigDecimal>, Boolean>> training_set) {
while (true) {
int error_count = 0;
for (Pair<List<BigDecimal>, Boolean> training_pair : training_set) {
List<BigDecimal> input_vector = training_pair.getValue0();
Integer desired_output = training_pair.getValue1() ? 1 : 0;
System.out.println(String.format("%s", this.weights));

int result = dot_product(input_vector, this.weights).compareTo(new BigDecimal(threshold)) > 0 ? 1 : 0;

int error = desired_output - result;
if (error != 0) {
error_count += 1;
int i; for (i=0; i<this.size; i++)
this.weights.set(i, this.weights.get(i).add(input_vector.get(i).multiply(new BigDecimal(this.learning_rate * error))));
}
}
if (error_count == 0)
break;
System.out.println();
}
}
}

0 comments on commit 656ea7f

Please sign in to comment.