Skip to content
This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Commit

Permalink
Merge pull request #42 from myui/feature/arow_regression
Browse files Browse the repository at this point in the history
Feature/arow regression
  • Loading branch information
myui committed Oct 21, 2013
2 parents 996bc0f + 816075c commit f2f00a2
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 27 deletions.
3 changes: 3 additions & 0 deletions scripts/ddl/define-all.hive
Expand Up @@ -156,6 +156,9 @@ create temporary function pa2_regress as 'hivemall.regression.PassiveAggressiveR
drop temporary function pa2a_regress;
create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a';

drop temporary function arow_regress;
create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF';

---------------------
-- array functions --
---------------------
Expand Down
5 changes: 4 additions & 1 deletion scripts/ddl/define-regression-udf.hive
Expand Up @@ -32,4 +32,7 @@ drop temporary function pa2_regress;
create temporary function pa2_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2';

-- Passive-Aggressive regression (PA2a variant)
drop temporary function pa2a_regress;
create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a';

-- AROW (Adaptive Regularization of Weight Vectors) regression
drop temporary function arow_regress;
create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF';
8 changes: 4 additions & 4 deletions src/main/hivemall/classifier/AROWClassifierUDTF.java
Expand Up @@ -129,18 +129,18 @@ protected void update(final List<?> features, final int y, final float alpha, fi
}

private static WeightValue getNewWeight(final WeightValue old, final float x, final float y, final float alpha, final float beta) {
final float old_v;
final float old_w;
final float old_cov;
if(old == null) {
old_v = 0.f;
old_w = 0.f;
old_cov = 1.f;
} else {
old_v = old.getValue();
old_w = old.getValue();
old_cov = old.getCovariance();
}

float cv = old_cov * x;
float new_w = old_v + (y * alpha * cv);
float new_w = old_w + (y * alpha * cv);
float new_cov = old_cov - (beta * cv * cv);

return new WeightValue(new_w, new_cov);
Expand Down
144 changes: 144 additions & 0 deletions src/main/hivemall/regression/AROWRegressionUDTF.java
@@ -0,0 +1,144 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2013
* National Institute of Advanced Industrial Science and Technology (AIST)
* Registration Number: H25PRO-1520
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
package hivemall.regression;

import hivemall.common.FeatureValue;
import hivemall.common.PredictionResult;
import hivemall.common.WeightValue;

import java.util.Collection;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

/**
 * Online regression using AROW (Adaptive Regularization of Weight Vectors).
 *
 * Maintains, per feature, a weight (mean) and a covariance (confidence).
 * Each example updates the mean proportionally to the loss and shrinks the
 * covariance, so frequently-seen features receive smaller and smaller updates.
 *
 * Usage: arow_regress(List&lt;Int|BigInt|Text&gt; features, float target [, constant string options])
 */
public class AROWRegressionUDTF extends OnlineRegressionUDTF {

    /** Regularization parameter r (&gt; 0); larger r makes updates more conservative. */
    protected float r;

    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        // Validate arity here; the superclass performs the detailed OI checks.
        final int numArgs = argOIs.length;
        if(numArgs != 2 && numArgs != 3) {
            throw new UDFArgumentException(getClass().getSimpleName()
                    + " takes arguments: List<Int|BigInt|Text> features, float target [, constant string options]");
        }

        return super.initialize(argOIs);
    }

    @Override
    protected Options getOptions() {
        Options opts = super.getOptions();
        opts.addOption("r", "regularization", true, "Regularization parameter for some r > 0 [default 0.1]");
        return opts;
    }

    /**
     * Parses the -r/--regularization option.
     *
     * @throws UDFArgumentException if the given value is not a number or not &gt; 0
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine cl = super.processOptions(argOIs);

        float r = 0.1f; // default regularization parameter
        if(cl != null) {
            String r_str = cl.getOptionValue("r");
            if(r_str != null) {
                try {
                    r = Float.parseFloat(r_str);
                } catch (NumberFormatException e) {
                    // surface a proper argument error instead of an unchecked parse failure
                    throw new UDFArgumentException("Regularization parameter must be a number: "
                            + r_str);
                }
                if(!(r > 0)) { // negated form also rejects NaN
                    throw new UDFArgumentException("Regularization parameter must be greater than 0: "
                            + r_str);
                }
            }
        }

        this.r = r;
        return cl;
    }

    /**
     * One AROW training step: compute score and variance of the example,
     * then update every feature weight with confidence coefficient
     * beta = 1 / (variance + r).
     */
    @Override
    protected void train(Collection<?> features, float target) {
        PredictionResult margin = calcScoreAndVariance(features);
        float predicted = margin.getScore();

        float loss = loss(target, predicted);

        float var = margin.getVariance();
        float beta = 1.f / (var + r); // confidence of this update

        update(features, loss, beta);
    }

    /**
     * Signed regression loss (not squared): the update direction depends on its sign.
     *
     * @return target - predicted
     */
    protected float loss(float target, float predicted) {
        return target - predicted; // y - m^Tx
    }

    /**
     * Applies the AROW update to every feature of the example and to the bias term.
     *
     * @param features the example's features
     * @param loss     signed loss (target - predicted)
     * @param beta     confidence coefficient 1 / (variance + r)
     */
    @Override
    protected void update(final Collection<?> features, final float loss, final float beta) {
        final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();

        for(Object f : features) {
            final Object k;
            final float v;
            if(parseX) {
                // feature encoded as "index:value" (optionally hashed)
                FeatureValue fv = FeatureValue.parse(f, feature_hashing);
                k = fv.getFeature();
                v = fv.getValue();
            } else {
                k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
                v = 1.f; // bare feature implies value 1
            }
            WeightValue old_w = weights.get(k);
            WeightValue new_w = getNewWeight(old_w, v, loss, beta);
            weights.put(k, new_w);
        }

        if(biasKey != null) {
            // bias is treated as a feature with constant value `bias`
            WeightValue old_bias = weights.get(biasKey);
            WeightValue new_bias = getNewWeight(old_bias, bias, loss, beta);
            weights.put(biasKey, new_bias);
        }
    }

    /**
     * Computes the AROW-updated weight and covariance for a single feature.
     *
     * @param old  previous entry, or null for an unseen feature
     * @param x    the feature value
     * @param loss signed loss (target - predicted)
     * @param beta confidence coefficient 1 / (variance + r)
     * @return new weight entry with updated mean and shrunk covariance
     */
    private static WeightValue getNewWeight(final WeightValue old, final float x, final float loss, final float beta) {
        final float old_w;
        final float old_cov;
        if(old == null) {
            // unseen feature: weight 0, covariance (confidence) initialized to 1
            old_w = 0.f;
            old_cov = 1.f;
        } else {
            old_w = old.getValue();
            old_cov = old.getCovariance();
        }

        float cov_x = old_cov * x;                       // Sigma_ii * x_i
        float new_w = old_w + loss * cov_x * beta;       // mean moves toward reducing the loss
        float new_cov = old_cov - (beta * cov_x * cov_x); // covariance monotonically shrinks

        return new WeightValue(new_w, new_cov);
    }

}
6 changes: 4 additions & 2 deletions src/main/hivemall/regression/LogressIterUDTF.java
Expand Up @@ -21,6 +21,7 @@
package hivemall.regression;

import hivemall.common.HivemallConstants;
import hivemall.common.WeightValue;

import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -81,12 +82,13 @@ public void process(Object[] args) throws HiveException {
Object k = e.getKey();
Object feature = ObjectInspectorUtils.copyToStandardObject(k, featureInspector);
if(!weights.containsKey(feature)) {
weights.put(feature, weight);
float v = weight.get();
weights.put(feature, new WeightValue(v));
}
}

Set<Object> features = featuresWithWeight.keySet();
train(weights, features, target);
train(features, target);
count++;
}

Expand Down
74 changes: 58 additions & 16 deletions src/main/hivemall/regression/OnlineRegressionUDTF.java
Expand Up @@ -23,6 +23,7 @@
import hivemall.common.FeatureValue;
import hivemall.common.HivemallConstants;
import hivemall.common.PredictionResult;
import hivemall.common.WeightValue;

import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -60,7 +61,7 @@ public abstract class OnlineRegressionUDTF extends GenericUDTF {
protected float bias;
protected Object biasKey;

protected Map<Object, FloatWritable> weights;
protected Map<Object, WeightValue> weights;
protected int count;

@Override
Expand Down Expand Up @@ -95,7 +96,7 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
fieldNames.add("weight");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);

this.weights = new HashMap<Object, FloatWritable>(8192);
this.weights = new HashMap<Object, WeightValue>(8192);
this.count = 1;

return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
Expand Down Expand Up @@ -168,13 +169,13 @@ public void process(Object[] args) throws HiveException {
float target = targetOI.get(args[1]);
checkTargetValue(target);

train(weights, features, target);
train(features, target);
count++;
}

protected void checkTargetValue(float target) throws UDFArgumentException {}

protected void train(final Map<Object, FloatWritable> weights, final Collection<?> features, final float target) {
protected void train(final Collection<?> features, final float target) {
float p = predict(features);
update(features, target, p);
}
Expand All @@ -195,14 +196,14 @@ protected float predict(final Collection<?> features) {
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
FloatWritable old_w = weights.get(k);
WeightValue old_w = weights.get(k);
if(old_w != null) {
score += (old_w.get() * v);
}
}

if(biasKey != null) {
FloatWritable biasWeight = weights.get(biasKey);
WeightValue biasWeight = weights.get(biasKey);
if(biasWeight != null) {
score += (biasWeight.get() * bias);
}
Expand All @@ -211,7 +212,7 @@ protected float predict(final Collection<?> features) {
return score;
}

protected PredictionResult calcScore(Collection<?> features) {
protected PredictionResult calcScoreAndNorm(Collection<?> features) {
final ObjectInspector featureInspector = this.featureInputOI;
final boolean parseX = this.parseX;

Expand All @@ -229,15 +230,15 @@ protected PredictionResult calcScore(Collection<?> features) {
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
FloatWritable old_w = weights.get(k);
WeightValue old_w = weights.get(k);
if(old_w != null) {
score += (old_w.get() * v);
}
squared_norm += (v * v);
}

if(biasKey != null) {
FloatWritable biasWeight = weights.get(biasKey);
WeightValue biasWeight = weights.get(biasKey);
if(biasWeight != null) {
score += (biasWeight.get() * bias);
}
Expand All @@ -247,6 +248,46 @@ protected PredictionResult calcScore(Collection<?> features) {
return new PredictionResult(score).squaredNorm(squared_norm);
}

/**
 * Computes both the score (w^T x) and the variance (x^T Sigma x, with a
 * diagonal covariance) of an example under the current model, as needed by
 * confidence-weighted learners such as AROW.
 *
 * Unseen features contribute 0 to the score and a default covariance of 1
 * to the variance. The bias term, when enabled, participates like a feature
 * with constant value {@code bias}.
 *
 * @param features the example's feature list
 * @return a PredictionResult carrying the score with the variance attached
 */
protected PredictionResult calcScoreAndVariance(Collection<?> features) {
final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();
final boolean parseX = this.parseX;

float score = 0.f;
float variance = 0.f;

for(Object f : features) {// a += w[i] * x[i]
final Object k;
final float v;
if(parseX) {
// feature encoded as "index:value" (optionally hashed)
FeatureValue fv = FeatureValue.parse(f, feature_hashing);
k = fv.getFeature();
v = fv.getValue();
} else {
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f; // bare feature implies value 1
}
WeightValue old_w = weights.get(k);
if(old_w == null) {
// unseen feature: weight 0 adds nothing to the score; covariance defaults to 1
variance += (1.f * v * v);
} else {
score += (old_w.getValue() * v);
variance += (old_w.getCovariance() * v * v);
}
}

if(biasKey != null) {
WeightValue biasWeight = weights.get(biasKey);
if(biasWeight == null) {
// bias not yet learned: default covariance of 1
variance += (1.f * bias * bias);
} else {
score += (biasWeight.getValue() * bias);
variance += (biasWeight.getCovariance() * bias * bias);
}
}

return new PredictionResult(score).variance(variance);
}

protected void update(Collection<?> features, float target, float predicted) {
float d = dloss(target, predicted);
update(features, d);
Expand All @@ -270,27 +311,28 @@ protected void update(Collection<?> features, float coeff) {
x = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
xi = 1.f;
}
FloatWritable old_w = weights.get(x);
WeightValue old_w = weights.get(x);
float new_w = (old_w == null) ? coeff * xi : old_w.get() + (coeff * xi);
weights.put(x, new FloatWritable(new_w));
weights.put(x, new WeightValue(new_w));
}

if(biasKey != null) {
FloatWritable old_bias = weights.get(biasKey);
WeightValue old_bias = weights.get(biasKey);
float new_bias = (old_bias == null) ? coeff * bias : old_bias.get() + (coeff * bias);
weights.put(biasKey, new FloatWritable(new_bias));
weights.put(biasKey, new WeightValue(new_bias));
}
}

@Override
public void close() throws HiveException {
if(weights != null) {
final Object[] forwardMapObj = new Object[2];
for(Map.Entry<Object, FloatWritable> e : weights.entrySet()) {
for(Map.Entry<Object, WeightValue> e : weights.entrySet()) {
Object k = e.getKey();
FloatWritable v = e.getValue();
WeightValue v = e.getValue();
FloatWritable fv = new FloatWritable(v.get());
forwardMapObj[0] = k;
forwardMapObj[1] = v;
forwardMapObj[1] = fv;
forward(forwardMapObj);
}
this.weights = null;
Expand Down

0 comments on commit f2f00a2

Please sign in to comment.