Permalink
Browse files

Merge pull request #42 from myui/feature/arow_regression

Feature/arow regression
  • Loading branch information...
2 parents 996bc0f + 816075c commit f2f00a27820685525c26c5d457189aba04db0600 @myui committed Oct 21, 2013
@@ -156,6 +156,9 @@ create temporary function pa2_regress as 'hivemall.regression.PassiveAggressiveR
drop temporary function pa2a_regress;
create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a';
+drop temporary function arow_regress;
+create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF';
+
---------------------
-- array functions --
---------------------
@@ -32,4 +32,7 @@ drop temporary function pa2_regress;
create temporary function pa2_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2';
drop temporary function pa2a_regress;
-create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a';
+create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a';
+
+drop temporary function arow_regress;
+create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF';
@@ -129,18 +129,18 @@ protected void update(final List<?> features, final int y, final float alpha, fi
}
private static WeightValue getNewWeight(final WeightValue old, final float x, final float y, final float alpha, final float beta) {
- final float old_v;
+ final float old_w;
final float old_cov;
if(old == null) {
- old_v = 0.f;
+ old_w = 0.f;
old_cov = 1.f;
} else {
- old_v = old.getValue();
+ old_w = old.getValue();
old_cov = old.getCovariance();
}
float cv = old_cov * x;
- float new_w = old_v + (y * alpha * cv);
+ float new_w = old_w + (y * alpha * cv);
float new_cov = old_cov - (beta * cv * cv);
return new WeightValue(new_w, new_cov);
@@ -0,0 +1,144 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.regression;
+
+import hivemall.common.FeatureValue;
+import hivemall.common.PredictionResult;
+import hivemall.common.WeightValue;
+
+import java.util.Collection;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
+/**
+ * Hive UDTF implementing AROW (Adaptive Regularization of Weight Vectors) for regression.
+ * Maintains per-feature (weight, covariance) pairs via {@link WeightValue}; unseen features
+ * implicitly start at weight 0 and covariance 1 (see getNewWeight). The update is
+ * confidence-weighted: beta = 1 / (variance + r), where r is the regularization parameter.
+ */
+public class AROWRegressionUDTF extends OnlineRegressionUDTF {
+
+ /** Regularization parameter r */
+ protected float r;
+
+ /**
+ * Validates the argument count (2 or 3: features, target [, options string]),
+ * then delegates schema setup to the superclass.
+ */
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+ if(numArgs != 2 && numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes arguments: List<Int|BigInt|Text> features, float target [, constant string options]");
+ }
+
+ return super.initialize(argOIs);
+ }
+
+ /** Adds the -r/--regularization option on top of the superclass options. */
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("r", "regularization", true, "Regularization parameter for some r > 0 [default 0.1]");
+ return opts;
+ }
+
+ /**
+ * Parses -r (default 0.1) and stores it in {@link #r}; rejects values not strictly
+ * positive. NOTE(review): a non-numeric -r value makes Float.parseFloat throw an
+ * unwrapped NumberFormatException rather than a UDFArgumentException — confirm intended.
+ */
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final CommandLine cl = super.processOptions(argOIs);
+
+ float r = 0.1f;
+ if(cl != null) {
+ String r_str = cl.getOptionValue("r");
+ if(r_str != null) {
+ r = Float.parseFloat(r_str);
+ // !(r > 0) rather than (r <= 0) so NaN is also rejected
+ if(!(r > 0)) {
+ throw new UDFArgumentException("Regularization parameter must be greater than 0: "
+ + r_str);
+ }
+ }
+ }
+
+ this.r = r;
+ return cl;
+ }
+
+ /**
+ * One AROW training step: compute score and variance, take loss = target - predicted,
+ * set the confidence coefficient beta = 1 / (variance + r), then update all weights.
+ */
+ @Override
+ protected void train(Collection<?> features, float target) {
+ PredictionResult margin = calcScoreAndVariance(features);
+ float predicted = margin.getScore();
+
+ float loss = loss(target, predicted);
+
+ float var = margin.getVariance();
+ float beta = 1.f / (var + r);
+
+ update(features, loss, beta);
+ }
+
+ /**
+ * @return target - predicted
+ */
+ protected float loss(float target, float predicted) {
+ return target - predicted; // y - m^Tx
+ }
+
+ /**
+ * Applies the AROW update to every feature weight and to the bias term (if enabled).
+ * Features are either parsed as "feature:value" pairs (parseX) or treated as value 1.0.
+ */
+ @Override
+ protected void update(final Collection<?> features, final float loss, final float beta) {
+ final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();
+
+ for(Object f : features) {
+ final Object k;
+ final float v;
+ if(parseX) {
+ FeatureValue fv = FeatureValue.parse(f, feature_hashing);
+ k = fv.getFeature();
+ v = fv.getValue();
+ } else {
+ // copy the key out of Hive's reusable object before using it as a map key
+ k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
+ v = 1.f;
+ }
+ WeightValue old_w = weights.get(k);
+ WeightValue new_w = getNewWeight(old_w, v, loss, beta);
+ weights.put(k, new_w);
+ }
+
+ if(biasKey != null) {
+ WeightValue old_bias = weights.get(biasKey);
+ WeightValue new_bias = getNewWeight(old_bias, bias, loss, beta);
+ weights.put(biasKey, new_bias);
+ }
+ }
+
+ /**
+ * Computes the updated (weight, covariance) pair for one feature:
+ * new_w = old_w + loss * beta * (old_cov * x)
+ * new_cov = old_cov - beta * (old_cov * x)^2
+ * An unseen feature (old == null) starts at weight 0 and covariance 1.
+ */
+ private static WeightValue getNewWeight(final WeightValue old, final float x, final float loss, final float beta) {
+ final float old_w;
+ final float old_cov;
+ if(old == null) {
+ old_w = 0.f;
+ old_cov = 1.f;
+ } else {
+ old_w = old.getValue();
+ old_cov = old.getCovariance();
+ }
+
+ float cov_x = old_cov * x;
+ float new_w = old_w + loss * cov_x * beta;
+ float new_cov = old_cov - (beta * cov_x * cov_x);
+
+ return new WeightValue(new_w, new_cov);
+ }
+
+}
@@ -21,6 +21,7 @@
package hivemall.regression;
import hivemall.common.HivemallConstants;
+import hivemall.common.WeightValue;
import java.util.Map;
import java.util.Set;
@@ -81,12 +82,13 @@ public void process(Object[] args) throws HiveException {
Object k = e.getKey();
Object feature = ObjectInspectorUtils.copyToStandardObject(k, featureInspector);
if(!weights.containsKey(feature)) {
- weights.put(feature, weight);
+ float v = weight.get();
+ weights.put(feature, new WeightValue(v));
}
}
Set<Object> features = featuresWithWeight.keySet();
- train(weights, features, target);
+ train(features, target);
count++;
}
@@ -23,6 +23,7 @@
import hivemall.common.FeatureValue;
import hivemall.common.HivemallConstants;
import hivemall.common.PredictionResult;
+import hivemall.common.WeightValue;
import java.util.ArrayList;
import java.util.Collection;
@@ -60,7 +61,7 @@
protected float bias;
protected Object biasKey;
- protected Map<Object, FloatWritable> weights;
+ protected Map<Object, WeightValue> weights;
protected int count;
@Override
@@ -95,7 +96,7 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
fieldNames.add("weight");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
- this.weights = new HashMap<Object, FloatWritable>(8192);
+ this.weights = new HashMap<Object, WeightValue>(8192);
this.count = 1;
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
@@ -168,13 +169,13 @@ public void process(Object[] args) throws HiveException {
float target = targetOI.get(args[1]);
checkTargetValue(target);
- train(weights, features, target);
+ train(features, target);
count++;
}
protected void checkTargetValue(float target) throws UDFArgumentException {}
- protected void train(final Map<Object, FloatWritable> weights, final Collection<?> features, final float target) {
+ protected void train(final Collection<?> features, final float target) {
float p = predict(features);
update(features, target, p);
}
@@ -195,14 +196,14 @@ protected float predict(final Collection<?> features) {
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- FloatWritable old_w = weights.get(k);
+ WeightValue old_w = weights.get(k);
if(old_w != null) {
score += (old_w.get() * v);
}
}
if(biasKey != null) {
- FloatWritable biasWeight = weights.get(biasKey);
+ WeightValue biasWeight = weights.get(biasKey);
if(biasWeight != null) {
score += (biasWeight.get() * bias);
}
@@ -211,7 +212,7 @@ protected float predict(final Collection<?> features) {
return score;
}
- protected PredictionResult calcScore(Collection<?> features) {
+ protected PredictionResult calcScoreAndNorm(Collection<?> features) {
final ObjectInspector featureInspector = this.featureInputOI;
final boolean parseX = this.parseX;
@@ -229,15 +230,15 @@ protected PredictionResult calcScore(Collection<?> features) {
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- FloatWritable old_w = weights.get(k);
+ WeightValue old_w = weights.get(k);
if(old_w != null) {
score += (old_w.get() * v);
}
squared_norm += (v * v);
}
if(biasKey != null) {
- FloatWritable biasWeight = weights.get(biasKey);
+ WeightValue biasWeight = weights.get(biasKey);
if(biasWeight != null) {
score += (biasWeight.get() * bias);
}
@@ -247,6 +248,46 @@ protected PredictionResult calcScore(Collection<?> features) {
return new PredictionResult(score).squaredNorm(squared_norm);
}
+ /**
+ * Computes both the prediction score and the confidence variance in one pass:
+ * score = sum_i w_i * x_i (+ bias weight * bias), variance = sum_i cov_i * x_i^2,
+ * where an unseen feature contributes variance with the default covariance 1.
+ * NOTE(review): this uses featureListOI.getListElementObjectInspector() while
+ * calcScoreAndNorm uses this.featureInputOI — confirm the inconsistency is intended.
+ */
+ protected PredictionResult calcScoreAndVariance(Collection<?> features) {
+ final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();
+ final boolean parseX = this.parseX;
+
+ float score = 0.f;
+ float variance = 0.f;
+
+ for(Object f : features) {// a += w[i] * x[i]
+ final Object k;
+ final float v;
+ if(parseX) {
+ FeatureValue fv = FeatureValue.parse(f, feature_hashing);
+ k = fv.getFeature();
+ v = fv.getValue();
+ } else {
+ // copy the key out of Hive's reusable object before the map lookup
+ k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
+ v = 1.f;
+ }
+ WeightValue old_w = weights.get(k);
+ if(old_w == null) {
+ // unseen feature: weight 0 contributes no score; covariance defaults to 1
+ variance += (1.f * v * v);
+ } else {
+ score += (old_w.getValue() * v);
+ variance += (old_w.getCovariance() * v * v);
+ }
+ }
+
+ if(biasKey != null) {
+ WeightValue biasWeight = weights.get(biasKey);
+ if(biasWeight == null) {
+ variance += (1.f * bias * bias);
+ } else {
+ score += (biasWeight.getValue() * bias);
+ variance += (biasWeight.getCovariance() * bias * bias);
+ }
+ }
+
+ return new PredictionResult(score).variance(variance);
+ }
+
protected void update(Collection<?> features, float target, float predicted) {
float d = dloss(target, predicted);
update(features, d);
@@ -270,27 +311,28 @@ protected void update(Collection<?> features, float coeff) {
x = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
xi = 1.f;
}
- FloatWritable old_w = weights.get(x);
+ WeightValue old_w = weights.get(x);
float new_w = (old_w == null) ? coeff * xi : old_w.get() + (coeff * xi);
- weights.put(x, new FloatWritable(new_w));
+ weights.put(x, new WeightValue(new_w));
}
if(biasKey != null) {
- FloatWritable old_bias = weights.get(biasKey);
+ WeightValue old_bias = weights.get(biasKey);
float new_bias = (old_bias == null) ? coeff * bias : old_bias.get() + (coeff * bias);
- weights.put(biasKey, new FloatWritable(new_bias));
+ weights.put(biasKey, new WeightValue(new_bias));
}
}
@Override
public void close() throws HiveException {
if(weights != null) {
final Object[] forwardMapObj = new Object[2];
- for(Map.Entry<Object, FloatWritable> e : weights.entrySet()) {
+ for(Map.Entry<Object, WeightValue> e : weights.entrySet()) {
Object k = e.getKey();
- FloatWritable v = e.getValue();
+ WeightValue v = e.getValue();
+ FloatWritable fv = new FloatWritable(v.get());
forwardMapObj[0] = k;
- forwardMapObj[1] = v;
+ forwardMapObj[1] = fv;
forward(forwardMapObj);
}
this.weights = null;
Oops, something went wrong.

0 comments on commit f2f00a2

Please sign in to comment.