Skip to content
This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Commit

Permalink
Added AROW with a hinge loss (arowh_regress())
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Oct 21, 2013
1 parent dff1d86 commit 239e90d
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 6 deletions.
3 changes: 3 additions & 0 deletions scripts/ddl/define-all.hive
Expand Up @@ -159,6 +159,9 @@ create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressive
drop temporary function arow_regress; drop temporary function arow_regress;
create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF'; create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF';


-- arowh_regress: bound to the AROWh nested class of AROWRegressionUDTF (AROW regression with a hinge loss)
drop temporary function arowh_regress;
create temporary function arowh_regress as 'hivemall.regression.AROWRegressionUDTF$AROWh';

--------------------- ---------------------
-- array functions -- -- array functions --
--------------------- ---------------------
Expand Down
5 changes: 4 additions & 1 deletion scripts/ddl/define-regression-udf.hive
Expand Up @@ -35,4 +35,7 @@ drop temporary function pa2a_regress;
create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a'; create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a';


drop temporary function arow_regress; drop temporary function arow_regress;
create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF'; create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF';

-- arowh_regress: bound to the AROWh nested class of AROWRegressionUDTF (AROW regression with a hinge loss)
drop temporary function arowh_regress;
create temporary function arowh_regress as 'hivemall.regression.AROWRegressionUDTF$AROWh';
61 changes: 56 additions & 5 deletions src/main/hivemall/regression/AROWRegressionUDTF.java
Expand Up @@ -21,6 +21,7 @@
package hivemall.regression; package hivemall.regression;


import hivemall.common.FeatureValue; import hivemall.common.FeatureValue;
import hivemall.common.LossFunctions.EpsilonInsensitiveLoss;
import hivemall.common.PredictionResult; import hivemall.common.PredictionResult;
import hivemall.common.WeightValue; import hivemall.common.WeightValue;


Expand Down Expand Up @@ -97,7 +98,7 @@ protected float loss(float target, float predicted) {
} }


@Override @Override
protected void update(final Collection<?> features, final float loss, final float beta) { protected void update(final Collection<?> features, final float coeff, final float beta) {
final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector(); final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();


for(Object f : features) { for(Object f : features) {
Expand All @@ -112,18 +113,18 @@ protected void update(final Collection<?> features, final float loss, final floa
v = 1.f; v = 1.f;
} }
WeightValue old_w = weights.get(k); WeightValue old_w = weights.get(k);
WeightValue new_w = getNewWeight(old_w, v, loss, beta); WeightValue new_w = getNewWeight(old_w, v, coeff, beta);
weights.put(k, new_w); weights.put(k, new_w);
} }


if(biasKey != null) { if(biasKey != null) {
WeightValue old_bias = weights.get(biasKey); WeightValue old_bias = weights.get(biasKey);
WeightValue new_bias = getNewWeight(old_bias, bias, loss, beta); WeightValue new_bias = getNewWeight(old_bias, bias, coeff, beta);
weights.put(biasKey, new_bias); weights.put(biasKey, new_bias);
} }
} }


private static WeightValue getNewWeight(final WeightValue old, final float x, final float loss, final float beta) { private static WeightValue getNewWeight(final WeightValue old, final float x, final float coeff, final float beta) {
final float old_w; final float old_w;
final float old_cov; final float old_cov;
if(old == null) { if(old == null) {
Expand All @@ -135,10 +136,60 @@ private static WeightValue getNewWeight(final WeightValue old, final float x, fi
} }


float cov_x = old_cov * x; float cov_x = old_cov * x;
float new_w = old_w + loss * cov_x * beta; float new_w = old_w + coeff * cov_x * beta;
float new_cov = old_cov - (beta * cov_x * cov_x); float new_cov = old_cov - (beta * cov_x * cov_x);


return new WeightValue(new_w, new_cov); return new WeightValue(new_w, new_cov);
} }


/**
 * AROW regression variant whose updates are driven by
 * {@link EpsilonInsensitiveLoss}: a training example only changes the model
 * when the value returned by {@link #loss(float, float)} is positive.
 * <p>
 * Adds one option on top of those of the enclosing class:
 * {@code -e} / {@code -epsilon} (default 0.1).
 */
public static class AROWh extends AROWRegressionUDTF {

    /** Value used for {@link #epsilon} when the option is not supplied. */
    private static final float DEFAULT_EPSILON = 0.1f;

    /** Sensitivity to prediction mistakes */
    protected float epsilon;

    /**
     * Extends the parent's option set with {@code -e} / {@code -epsilon}.
     *
     * @return the parent's options plus the epsilon option
     */
    @Override
    protected Options getOptions() {
        Options opts = super.getOptions();
        opts.addOption("e", "epsilon", true, "Sensitivity to prediction mistakes [default 0.1]");
        return opts;
    }

    /**
     * Lets the parent process its own options, then reads {@code -epsilon}
     * into {@link #epsilon}, falling back to the default when the option
     * (or the whole command line) is absent.
     *
     * @param argOIs argument inspectors, passed through to the parent
     * @return the command line returned by the parent (may be {@code null})
     * @throws UDFArgumentException if the parent rejects the arguments or the
     *         epsilon value is not a valid float
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine cl = super.processOptions(argOIs);

        float parsedEpsilon = DEFAULT_EPSILON;
        if(cl != null) {
            final String opt_epsilon = cl.getOptionValue("epsilon");
            if(opt_epsilon != null) {
                try {
                    parsedEpsilon = Float.parseFloat(opt_epsilon);
                } catch (NumberFormatException e) {
                    // Report a malformed option as an argument error rather than
                    // letting a bare NumberFormatException escape.
                    throw new UDFArgumentException("Invalid value for -epsilon option: "
                            + opt_epsilon);
                }
            }
        }

        this.epsilon = parsedEpsilon;
        return cl;
    }

    /**
     * Processes one training example. The model is updated only when the
     * loss is positive; the update coefficient is the loss carrying the sign
     * of {@code target - predicted}.
     *
     * @param features feature collection of the example
     * @param target   target value of the example
     */
    @Override
    protected void train(Collection<?> features, float target) {
        PredictionResult margin = calcScoreAndVariance(features);
        float predicted = margin.getScore();

        float loss = loss(target, predicted);
        if(loss > 0.f) {
            // Signed step: positive when under-predicting, negative otherwise.
            float coeff = (target - predicted) > 0.f ? loss : -loss;
            float var = margin.getVariance();
            // r is inherited from the enclosing class (presumably its
            // regularization parameter -- confirm against the parent).
            float beta = 1.f / (var + r);
            update(features, coeff, beta);
        }
    }

    /**
     * Epsilon-insensitive loss, computed by
     * {@link EpsilonInsensitiveLoss#loss}; documented by the original author
     * as {@code |w^t - y| - epsilon}.
     * NOTE(review): presumably clipped at zero by the helper -- confirm
     * against LossFunctions.
     *
     * @param target    true target value
     * @param predicted predicted value
     * @return the loss for this prediction
     */
    @Override
    protected float loss(float target, float predicted) {
        return EpsilonInsensitiveLoss.loss(predicted, target, epsilon);
    }
}

} }
Binary file modified target/hivemall.jar
Binary file not shown.

0 comments on commit 239e90d

Please sign in to comment.