Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Added AROW with a hinge loss (arowh_regress())

  • Loading branch information...
commit 239e90d537c1094e6b6003e712b0bd0f7b624eb7 1 parent dff1d86
@myui authored
View
3  scripts/ddl/define-all.hive
@@ -159,6 +159,9 @@ create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressive
drop temporary function arow_regress;
create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF';
+drop temporary function arowh_regress;
+create temporary function arowh_regress as 'hivemall.regression.AROWRegressionUDTF$AROWh';
+
---------------------
-- array functions --
---------------------
View
5 scripts/ddl/define-regression-udf.hive
@@ -35,4 +35,7 @@ drop temporary function pa2a_regress;
create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a';
drop temporary function arow_regress;
-create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF';
+create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF';
+
+drop temporary function arowh_regress;
+create temporary function arowh_regress as 'hivemall.regression.AROWRegressionUDTF$AROWh';
View
61 src/main/hivemall/regression/AROWRegressionUDTF.java
@@ -21,6 +21,7 @@
package hivemall.regression;
import hivemall.common.FeatureValue;
+import hivemall.common.LossFunctions.EpsilonInsensitiveLoss;
import hivemall.common.PredictionResult;
import hivemall.common.WeightValue;
@@ -97,7 +98,7 @@ protected float loss(float target, float predicted) {
}
@Override
- protected void update(final Collection<?> features, final float loss, final float beta) {
+ protected void update(final Collection<?> features, final float coeff, final float beta) {
final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();
for(Object f : features) {
@@ -112,18 +113,18 @@ protected void update(final Collection<?> features, final float loss, final floa
v = 1.f;
}
WeightValue old_w = weights.get(k);
- WeightValue new_w = getNewWeight(old_w, v, loss, beta);
+ WeightValue new_w = getNewWeight(old_w, v, coeff, beta);
weights.put(k, new_w);
}
if(biasKey != null) {
WeightValue old_bias = weights.get(biasKey);
- WeightValue new_bias = getNewWeight(old_bias, bias, loss, beta);
+ WeightValue new_bias = getNewWeight(old_bias, bias, coeff, beta);
weights.put(biasKey, new_bias);
}
}
- private static WeightValue getNewWeight(final WeightValue old, final float x, final float loss, final float beta) {
+ private static WeightValue getNewWeight(final WeightValue old, final float x, final float coeff, final float beta) {
final float old_w;
final float old_cov;
if(old == null) {
@@ -135,10 +136,60 @@ private static WeightValue getNewWeight(final WeightValue old, final float x, fi
}
float cov_x = old_cov * x;
- float new_w = old_w + loss * cov_x * beta;
+ float new_w = old_w + coeff * cov_x * beta;
float new_cov = old_cov - (beta * cov_x * cov_x);
return new WeightValue(new_w, new_cov);
}
+ public static class AROWh extends AROWRegressionUDTF {
+
+ /** Sensitivity to prediction mistakes */
+ protected float epsilon;
+
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("e", "epsilon", true, "Sensitivity to prediction mistakes [default 0.1]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = super.processOptions(argOIs);
+
+ float epsilon = 0.1f;
+ if(cl != null) {
+ String opt_epsilon = cl.getOptionValue("epsilon");
+ if(opt_epsilon != null) {
+ epsilon = Float.parseFloat(opt_epsilon);
+ }
+ }
+
+ this.epsilon = epsilon;
+ return cl;
+ }
+
+ @Override
+ protected void train(Collection<?> features, float target) {
+ PredictionResult margin = calcScoreAndVariance(features);
+ float predicted = margin.getScore();
+
+ float loss = loss(target, predicted);
+ if(loss > 0.f) {
+ float coeff = (target - predicted) > 0.f ? loss : -loss;
+ float var = margin.getVariance();
+ float beta = 1.f / (var + r);
+ update(features, coeff, beta);
+ }
+ }
+
+ /**
+ * |w^t - y| - epsilon
+ */
+ protected float loss(float target, float predicted) {
+ return EpsilonInsensitiveLoss.loss(predicted, target, epsilon);
+ }
+ }
+
}
View
BIN  target/hivemall.jar
Binary file not shown
Please sign in to comment.
Something went wrong with that request. Please try again.