Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

WIP perceptron train method

  • Loading branch information...
commit 656ea7f4df39410b2918620859d0fa8a3a3e00e3 1 parent 1c907b2
@git2samus git2samus authored git2samus committed
View
1  storm-ml/project.clj
@@ -6,4 +6,5 @@
:javac-options {:debug "true" :fork "true"}
:aot :all
:dev-dependencies [[org.clojure/clojure "1.4.0"]
+ [org.javatuples/javatuples "1.2"]
[storm "0.7.2"]])
View
40 storm-ml/src/jvm/storm/ml/Main.java
@@ -0,0 +1,40 @@
+package storm.ml;
+
+import java.lang.Boolean;
+import java.math.BigDecimal;
+import java.util.List;
+import java.util.ArrayList;
+
+import org.javatuples.Pair;
+
+import storm.ml.PerceptronTopologyBuilder;
+
+public class Main {
+ public static void main(String[] keywords) {
+ List<Pair<List<BigDecimal>, Boolean>> training_set = new ArrayList<Pair<List<BigDecimal>, Boolean>>(4);
+ List<BigDecimal> input_vector = new ArrayList<BigDecimal>(3);
+
+ input_vector.add(new BigDecimal(1));
+ input_vector.add(new BigDecimal(0));
+ input_vector.add(new BigDecimal(0));
+ training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, true));
+
+ input_vector.add(new BigDecimal(1));
+ input_vector.add(new BigDecimal(0));
+ input_vector.add(new BigDecimal(1));
+ training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, true));
+
+ input_vector.add(new BigDecimal(1));
+ input_vector.add(new BigDecimal(1));
+ input_vector.add(new BigDecimal(0));
+ training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, true));
+
+ input_vector.add(new BigDecimal(1));
+ input_vector.add(new BigDecimal(1));
+ input_vector.add(new BigDecimal(1));
+ training_set.add(new Pair<List<BigDecimal>, Boolean>(input_vector, false));
+
+ PerceptronTopologyBuilder ptb = new PerceptronTopologyBuilder(3, 0.5, 0.1);
+ ptb.train(training_set);
+ }
+}
View
59 storm-ml/src/jvm/storm/ml/PerceptronTopologyBuilder.java
@@ -0,0 +1,59 @@
+package storm.ml;
+
+import java.lang.Integer;
+import java.lang.Double;
+import java.lang.Boolean;
+import java.math.BigDecimal;
+import java.util.List;
+import java.util.ArrayList;
+
+import org.javatuples.Pair;
+
+public class PerceptronTopologyBuilder {
+ public final Integer size;
+ public final Double threshold;
+ public final Double learning_rate;
+
+ private List<BigDecimal> weights;
+
+ public PerceptronTopologyBuilder(Integer size, Double threshold, Double learning_rate) {
+ this.size = size; // size of the weight array and input
+ this.threshold = threshold; // margin to determine positive results
+ this.learning_rate = learning_rate; // adaptation factor for the weights
+
+ this.weights = new ArrayList<BigDecimal>(size);
+ int i; for (i=0; i<size; i++)
+ this.weights.add(new BigDecimal(0));
+ }
+
+ private BigDecimal dot_product(List<BigDecimal> vector1, List<BigDecimal> vector2) {
+ BigDecimal result = new BigDecimal(0);
+ int i; for (i=0; i<this.size; i++)
+ result.add(vector1.get(i).multiply(vector2.get(i)));
+
+ return result;
+ }
+
+ public void train(List<Pair<List<BigDecimal>, Boolean>> training_set) {
+ while (true) {
+ int error_count = 0;
+ for (Pair<List<BigDecimal>, Boolean> training_pair : training_set) {
+ List<BigDecimal> input_vector = training_pair.getValue0();
+ Integer desired_output = training_pair.getValue1() ? 1 : 0;
+ System.out.println(String.format("%s", this.weights));
+
+ int result = dot_product(input_vector, this.weights).compareTo(new BigDecimal(threshold)) > 0 ? 1 : 0;
+
+ int error = desired_output - result;
+ if (error != 0) {
+ error_count += 1;
+ int i; for (i=0; i<this.size; i++)
+ this.weights.set(i, this.weights.get(i).add(input_vector.get(i).multiply(new BigDecimal(this.learning_rate * error))));
+ }
+ }
+ if (error_count == 0)
+ break;
+ System.out.println();
+ }
+ }
+}
Please sign in to comment.
Something went wrong with that request. Please try again.