Learner.java (109 lines, 90 loc, 3.4 KB)
forked from nathanmarz/storm-contrib
package com.twitter.algorithms;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import net.spy.memcached.MemcachedClient;

import org.apache.commons.lang.ArrayUtils;
import org.apache.log4j.Logger;

import com.twitter.data.Example;
import com.twitter.util.Datautil;
import com.twitter.util.MathUtil;
/** An online linear classifier trained with stochastic gradient descent. */
public class Learner implements Serializable {
  public static final Logger LOG = Logger.getLogger(Learner.class);

  protected double[] weights;
  protected LossFunction lossFunction;

  // Running statistics; cleared by resetStats().
  int numExamples = 0;
  int numMisclassified = 0;
  double totalLoss = 0.0;
  double gradientSum = 0.0;

  protected double learningRate = 0.0;

  public Learner(int dimension) {
    weights = new double[dimension];
    lossFunction = new LossFunction(2);
  }
  /** Pulls the latest model from memcache, scores one example, and applies an SGD step. */
  public void update(Example example, int epoch, MemcachedClient memcache) {
    // Refresh the local model from the shared copy stored under "model".
    String serializedWeights = (String) memcache.get("model");
    List<Double> sharedWeights = Datautil.parse_str_vector(serializedWeights);
    Double[] boxedWeights = sharedWeights.toArray(new Double[sharedWeights.size()]);
    this.setWeights(ArrayUtils.toPrimitive(boxedWeights));
    LOG.debug("fetched weights[0] = " + boxedWeights[0]);

    int predicted = predict(example);
    updateStats(example, predicted);
    LOG.debug("EXAMPLE " + example.label + " PREDICTED: " + predicted);

    // Gradient step on misclassified labeled examples: w <- w - eta * gradient.
    if (example.isLabeled && predicted != example.label) {
      List<Double> gradient = lossFunction.gradient(example, predicted);
      gradientSum += MathUtil.l2norm(gradient);
      double eta = getLearningRate(example, epoch);
      for (int i = 0; i < this.weights.length; i++) {
        this.weights[i] -= eta * gradient.get(i);
      }
    }
    displayStats();
  }
  /** Inverse-square-root decay schedule: eta_t = eta_0 / sqrt(t), guarded for t = 0. */
  protected double getLearningRate(Example example, int timestamp) {
    return learningRate / Math.sqrt(Math.max(1, timestamp));
  }
  public double[] getWeights() {
    return weights;
  }

  /** Returns the weights boxed into a List of Objects, e.g. for emitting as a tuple. */
  public List<Object> getWeightsArray() {
    List<Object> weightArray = new ArrayList<Object>();
    for (double weight : weights) {
      weightArray.add(weight);
    }
    return weightArray;
  }

  /** Accumulated L2 norm of the gradients applied so far (see update). */
  public double getParallelUpdateWeight() {
    return gradientSum;
  }

  public void initWeights(double[] newWeights) {
    assert newWeights.length == weights.length;
    weights = Arrays.copyOf(newWeights, newWeights.length);
  }

  /** Linear threshold prediction: the sign (+1/-1) of the dot product w . x. */
  public int predict(Example example) {
    double dot = MathUtil.dot(weights, example.x);
    return (dot >= 0.0) ? 1 : -1;
  }
  protected void updateStats(Example example, int prediction) {
    numExamples++;
    if (example.label != prediction) {
      numMisclassified++;
    }
    totalLoss += lossFunction.get(example, prediction);
  }

  public void setWeights(double[] weights) {
    this.weights = weights;
  }
  public void displayStats() {
    if (numExamples == 0) {
      LOG.info("No examples seen so far.");
      return;
    }
    double accuracy = 1.0 - (double) numMisclassified / numExamples;
    double meanLoss = totalLoss / numExamples;
    LOG.info(String.format("Accuracy: %g\tMean Loss: %g", accuracy, meanLoss));
  }

  public void resetStats() {
    numExamples = 0;
    numMisclassified = 0;
    totalLoss = 0.0;
  }
}
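
// A minimal usage sketch, not part of the original file: it exercises only the
// Learner API shown above. The no-arg Example constructor, its directly
// assignable public x/label/isLabeled fields, and the 3-dimensional feature
// vector are assumptions for illustration, not the project's confirmed API.
class LearnerUsageSketch {
  public static void main(String[] args) {
    Learner learner = new Learner(3);                      // model over 3 features
    learner.initWeights(new double[] { 0.5, -0.25, 1.0 }); // seed the weights

    Example example = new Example();            // assumed: no-arg constructor
    example.x = new double[] { 1.0, 0.0, 2.0 }; // assumed: public double[] field
    example.label = 1.0;
    example.isLabeled = true;

    // predict() returns the sign (+1/-1) of the dot product w . x;
    // only update() records examples into the running statistics.
    int predicted = learner.predict(example);
    System.out.println("predicted label: " + predicted);
  }
}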