Permalink
Browse files

fixed a bug in PA1a and PA2a that stddev was not calculated correctly.

  • Loading branch information...
1 parent 53c3dd6 commit 75ccdd336cf109866a7910c1f1095dd4ebfb5128 @myui committed Oct 5, 2013
Showing with 20 additions and 6 deletions.
  1. +20 −6 src/main/hivemall/regression/PassiveAggressiveRegressionUDTF.java
@@ -94,6 +94,8 @@ protected float aggressiveness() {
@Override
protected void train(Map<Object, FloatWritable> weights, Collection<?> features, float target) {
+ preTrain(target);
+
PredictionResult margin = calcScore(features);
float predicted = margin.getScore();
float loss = loss(target, predicted);
@@ -108,6 +110,8 @@ protected void train(Map<Object, FloatWritable> weights, Collection<?> features,
}
}
+ protected void preTrain(float target) {}
+
/**
* |w^t - y| - epsilon
*/
@@ -126,18 +130,23 @@ protected float eta(float loss, PredictionResult margin) {
public static class PA1a extends PassiveAggressiveRegressionUDTF {
- private OnlineVariance target_stddev;
+ private OnlineVariance targetStdDev;
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs)
throws UDFArgumentException {
- this.target_stddev = new OnlineVariance();
+ this.targetStdDev = new OnlineVariance();
return super.initialize(argOIs);
}
@Override
+ protected void preTrain(float target) {
+ targetStdDev.handle(target);
+ }
+
+ @Override
protected float loss(float target, float predicted) {
- float stddev = (float) target_stddev.stddev();
+ float stddev = (float) targetStdDev.stddev();
float e = epsilon * stddev;
return EpsilonInsensitiveLoss.loss(predicted, target, e);
}
@@ -162,18 +171,23 @@ protected float eta(float loss, PredictionResult margin) {
public static class PA2a extends PA2 {
- private OnlineVariance target_stddev;
+ private OnlineVariance targetStdDev;
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs)
throws UDFArgumentException {
- this.target_stddev = new OnlineVariance();
+ this.targetStdDev = new OnlineVariance();
return super.initialize(argOIs);
}
@Override
+ protected void preTrain(float target) {
+ targetStdDev.handle(target);
+ }
+
+ @Override
protected float loss(float target, float predicted) {
- float stddev = (float) target_stddev.stddev();
+ float stddev = (float) targetStdDev.stddev();
float e = epsilon * stddev;
return EpsilonInsensitiveLoss.loss(predicted, target, e);
}

0 comments on commit 75ccdd3

Please sign in to comment.