This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Fixed UDFs to return Hive values instead of Java primitive values
because returning Java primitive values causes a ClassCastException in a certain case.
myui committed Jun 12, 2014
1 parent 160f66d commit 71a3415
Showing 33 changed files with 268 additions and 238 deletions.
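
Every rewritten method below funnels its result through hivemall.utils.WritableUtils. That helper's own source is not among the hunks shown here, so the following is only a plausible sketch, inferred from the call sites in this commit (val(String), val(int), val(double), and val(String[])); the actual implementation may differ.

    // Hypothetical sketch of hivemall.utils.WritableUtils, inferred from the
    // call sites in this commit; not taken from the repository.
    package hivemall.utils;

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.hadoop.hive.serde2.io.DoubleWritable;
    import org.apache.hadoop.io.IntWritable;
    import org.apache.hadoop.io.Text;

    public final class WritableUtils {

        private WritableUtils() {} // utility class; no instances

        // Wraps a String as Hive's Text type (null in, null out).
        public static Text val(final String v) {
            return (v == null) ? null : new Text(v);
        }

        // Wraps a primitive int as an IntWritable.
        public static IntWritable val(final int v) {
            return new IntWritable(v);
        }

        // Wraps a primitive double as Hive's DoubleWritable.
        public static DoubleWritable val(final double v) {
            return new DoubleWritable(v);
        }

        // Wraps each element of a String array as Text.
        public static List<Text> val(final String... v) {
            if(v == null) {
                return null;
            }
            final List<Text> list = new ArrayList<Text>(v.length);
            for(String s : v) {
                list.add(val(s));
            }
            return list;
        }
    }

Wrapping the results in Writable types this way is what the commit relies on to avoid the ClassCastException described above.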
7 changes: 5 additions & 2 deletions src/main/hivemall/ensemble/MaxValueLabelUDAF.java
@@ -20,8 +20,11 @@
  */
 package hivemall.ensemble;
 
+import hivemall.utils.WritableUtils;
+
 import org.apache.hadoop.hive.ql.exec.UDAF;
 import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
+import org.apache.hadoop.io.Text;
 
 public class MaxValueLabelUDAF extends UDAF {
 
@@ -81,11 +84,11 @@ public boolean merge(PartialResult other) {
             return true;
         }
 
-        public String terminate() {
+        public Text terminate() {
             if(partial == null) {
                 return null; // null to indicate that no values have been aggregated yet
             }
-            return partial.label;
+            return WritableUtils.val(partial.label);
         }
     }
 }
11 changes: 7 additions & 4 deletions src/main/hivemall/ensemble/bagging/VotedAvgUDAF.java
@@ -20,8 +20,11 @@
  */
 package hivemall.ensemble.bagging;
 
+import static hivemall.utils.WritableUtils.val;
+
 import org.apache.hadoop.hive.ql.exec.UDAF;
 import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 
 public class VotedAvgUDAF extends UDAF {
 
@@ -84,18 +87,18 @@ public boolean merge(PartialResult other) {
             return true;
         }
 
-        public Double terminate() {
+        public DoubleWritable terminate() {
             if(partial == null) {
                 return null; // null to indicate that no values have been aggregated yet
             }
             if(partial.positiveCnt > partial.negativeCnt) {
-                return partial.positiveSum / partial.positiveCnt;
+                return val(partial.positiveSum / partial.positiveCnt);
             } else {
                 if(partial.negativeCnt == 0) {
                     assert (partial.negativeSum == 0d) : partial.negativeSum;
-                    return 0.d;
+                    return val(0.d);
                }
-                return partial.negativeSum / partial.negativeCnt;
+                return val(partial.negativeSum / partial.negativeCnt);
             }
         }
     }
11 changes: 7 additions & 4 deletions src/main/hivemall/ensemble/bagging/WeightVotedAvgUDAF.java
@@ -20,8 +20,11 @@
  */
 package hivemall.ensemble.bagging;
 
+import static hivemall.utils.WritableUtils.val;
+
 import org.apache.hadoop.hive.ql.exec.UDAF;
 import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 
 public class WeightVotedAvgUDAF extends UDAF {
 
@@ -84,18 +87,18 @@ public boolean merge(PartialResult other) {
             return true;
         }
 
-        public Double terminate() {
+        public DoubleWritable terminate() {
             if(partial == null) {
                 return null; // null to indicate that no values have been aggregated yet
             }
             if(partial.positiveSum > (-partial.negativeSum)) {
-                return partial.positiveSum / partial.positiveCnt;
+                return val(partial.positiveSum / partial.positiveCnt);
             } else {
                 if(partial.negativeCnt == 0) {
                     assert (partial.negativeSum == 0d) : partial.negativeSum;
-                    return 0.d;
+                    return val(0.d);
                }
-                return partial.negativeSum / partial.negativeCnt;
+                return val(partial.negativeSum / partial.negativeCnt);
            }
        }
    }
11 changes: 6 additions & 5 deletions src/main/hivemall/ftvec/AddBiasUDF.java
@@ -21,30 +21,31 @@
 package hivemall.ftvec;
 
 import hivemall.HivemallConstants;
+import hivemall.utils.WritableUtils;
 
-import java.util.Arrays;
 import java.util.List;
 
 import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.io.Text;
 
 public class AddBiasUDF extends UDF {
 
-    public List<String> evaluate(List<String> ftvec) {
+    public List<Text> evaluate(List<String> ftvec) {
         String biasClause = Integer.toString(HivemallConstants.BIAS_CLAUSE_INT);
         return evaluate(ftvec, biasClause);
     }
 
-    public List<String> evaluate(List<String> ftvec, String biasClause) {
+    public List<Text> evaluate(List<String> ftvec, String biasClause) {
         float biasValue = 1.f;
         return evaluate(ftvec, biasClause, biasValue);
     }
 
-    public List<String> evaluate(List<String> ftvec, String biasClause, float biasValue) {
+    public List<Text> evaluate(List<String> ftvec, String biasClause, float biasValue) {
         int size = ftvec.size();
         String[] newvec = new String[size + 1];
         ftvec.toArray(newvec);
         newvec[size] = biasClause + ":" + Float.toString(biasValue);
-        return Arrays.asList(newvec);
+        return WritableUtils.val(newvec);
     }
 
 }
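
As a usage sketch (not from the commit), and assuming HivemallConstants.BIAS_CLAUSE_INT is 0, appending the bias feature to a two-element vector would behave like this when the UDF is driven directly:

    // Illustrative direct call; in a real query Hive invokes evaluate() itself.
    AddBiasUDF udf = new AddBiasUDF();
    List<Text> withBias = udf.evaluate(Arrays.asList("1:0.5", "3:2.0"));
    // Assuming BIAS_CLAUSE_INT == 0, withBias holds the Text values
    // ["1:0.5", "3:2.0", "0:1.0"].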
17 changes: 9 additions & 8 deletions src/main/hivemall/ftvec/ConvertToDenseModelUDAF.java
@@ -26,12 +26,13 @@
 
 import org.apache.hadoop.hive.ql.exec.UDAF;
 import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
+import org.apache.hadoop.io.FloatWritable;
 
 public class ConvertToDenseModelUDAF extends UDAF {
 
     public static class Evaluator implements UDAFEvaluator {
 
-        private List<Float> partial;
+        private List<FloatWritable> partial;
 
         @Override
         public void init() {
@@ -40,36 +41,36 @@ public void init() {
 
         public boolean iterate(int feature, float weight, int nDims) {
             if(partial == null) {
-                Float[] array = new Float[nDims];
+                FloatWritable[] array = new FloatWritable[nDims];
                 this.partial = Arrays.asList(array);
             }
-            partial.set(feature, new Float(weight));
+            partial.set(feature, new FloatWritable(weight));
             return true;
         }
 
-        public List<Float> terminatePartial() {
+        public List<FloatWritable> terminatePartial() {
             return partial;
         }
 
-        public boolean merge(List<Float> other) {
+        public boolean merge(List<FloatWritable> other) {
             if(other == null) {
                 return true;
             }
             if(partial == null) {
-                this.partial = new ArrayList<Float>(other);
+                this.partial = new ArrayList<FloatWritable>(other);
                 return true;
             }
             final int nDims = other.size();
             for(int i = 0; i < nDims; i++) {
-                Float x = other.set(i, null);
+                FloatWritable x = other.set(i, null);
                 if(x != null) {
                     partial.set(i, x);
                 }
             }
             return true;
         }
 
-        public List<Float> terminate() {
+        public List<FloatWritable> terminate() {
             if(partial == null) {
                 return null; // null to indicate that no values have been aggregated yet
             }
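
To see the element-type change at work, here is a minimal test-style driver for the evaluator above (illustrative only; it assumes init() resets partial to null, which the collapsed hunk hides):

    // Illustrative driver; class and method names are taken from the diff above.
    ConvertToDenseModelUDAF.Evaluator eval = new ConvertToDenseModelUDAF.Evaluator();
    eval.init();
    eval.iterate(0, 0.5f, 3);  // feature 0, weight 0.5, model of 3 dimensions
    eval.iterate(2, -1.0f, 3); // feature 2, weight -1.0
    // Each element is now a FloatWritable rather than a java.lang.Float,
    // matching the Writable types this commit standardizes on.
    List<FloatWritable> model = eval.terminate(); // [0.5, null, -1.0]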
6 changes: 4 additions & 2 deletions src/main/hivemall/ftvec/SortByFeatureUDF.java
@@ -24,11 +24,13 @@
 import java.util.TreeMap;
 
 import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
 
 public class SortByFeatureUDF extends UDF {
 
-    public Map<Integer, Float> evaluate(Map<Integer, Float> arg) {
-        Map<Integer, Float> ret = new TreeMap<Integer, Float>();
+    public Map<IntWritable, FloatWritable> evaluate(Map<IntWritable, FloatWritable> arg) {
+        Map<IntWritable, FloatWritable> ret = new TreeMap<IntWritable, FloatWritable>();
         ret.putAll(arg);
         return ret;
     }
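
The switch to IntWritable keys still sorts correctly because TreeMap uses the keys' natural ordering, and IntWritable implements WritableComparable with a numeric comparison. A quick illustration (not from the commit):

    // IntWritable keys keep TreeMap's ordering by feature index.
    Map<IntWritable, FloatWritable> m = new TreeMap<IntWritable, FloatWritable>();
    m.put(new IntWritable(3), new FloatWritable(0.1f));
    m.put(new IntWritable(1), new FloatWritable(0.9f));
    // Iteration order: {1=0.9, 3=0.1}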
18 changes: 10 additions & 8 deletions src/main/hivemall/ftvec/hashing/ArrayHashValuesUDF.java
@@ -20,45 +20,47 @@
  */
 package hivemall.ftvec.hashing;
 
+import static hivemall.utils.WritableUtils.val;
 import hivemall.utils.hashing.MurmurHash3;
 
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
 import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.io.IntWritable;
 
 public class ArrayHashValuesUDF extends UDF {
 
-    public List<Integer> evaluate(List<String> values) {
+    public List<IntWritable> evaluate(List<String> values) {
         return evaluate(values, null, MurmurHash3.DEFAULT_NUM_FEATURES);
     }
 
-    public List<Integer> evaluate(List<String> values, String prefix) {
+    public List<IntWritable> evaluate(List<String> values, String prefix) {
         return evaluate(values, prefix, MurmurHash3.DEFAULT_NUM_FEATURES);
     }
 
-    public List<Integer> evaluate(List<String> values, String prefix, boolean useIndexAsPrefix) {
+    public List<IntWritable> evaluate(List<String> values, String prefix, boolean useIndexAsPrefix) {
         return evaluate(values, prefix, MurmurHash3.DEFAULT_NUM_FEATURES, useIndexAsPrefix);
     }
 
-    public List<Integer> evaluate(List<String> values, String prefix, int numFeatures) {
+    public List<IntWritable> evaluate(List<String> values, String prefix, int numFeatures) {
         return evaluate(values, prefix, numFeatures, false);
     }
 
-    public List<Integer> evaluate(List<String> values, String prefix, int numFeatures, boolean useIndexAsPrefix) {
+    public List<IntWritable> evaluate(List<String> values, String prefix, int numFeatures, boolean useIndexAsPrefix) {
         return hashValues(values, prefix, numFeatures, useIndexAsPrefix);
     }
 
-    static List<Integer> hashValues(List<String> values, String prefix, int numFeatures, boolean useIndexAsPrefix) {
+    static List<IntWritable> hashValues(List<String> values, String prefix, int numFeatures, boolean useIndexAsPrefix) {
         if(values == null) {
             return null;
         }
         if(values.isEmpty()) {
             return Collections.emptyList();
         }
         final int size = values.size();
-        final Integer[] ary = new Integer[size];
+        final IntWritable[] ary = new IntWritable[size];
         for(int i = 0; i < size; i++) {
             String v = values.get(i);
             if(v == null) {
@@ -68,7 +70,7 @@ static List<Integer> hashValues(List<String> values, String prefix, int numFeatu
                     v = i + ':' + v;
                 }
                 String data = (prefix == null) ? v : (prefix + v);
-                ary[i] = MurmurHash3.murmurhash3(data, numFeatures);
+                ary[i] = val(MurmurHash3.murmurhash3(data, numFeatures));
             }
         }
         return Arrays.asList(ary);
15 changes: 9 additions & 6 deletions src/main/hivemall/ftvec/hashing/ArrayPrefixedHashValuesUDF.java
@@ -20,33 +20,36 @@
  */
 package hivemall.ftvec.hashing;
 
+import static hivemall.utils.WritableUtils.val;
 import hivemall.utils.hashing.MurmurHash3;
 
 import java.util.Arrays;
 import java.util.List;
 
 import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
 
 public class ArrayPrefixedHashValuesUDF extends UDF {
 
-    public List<String> evaluate(List<String> values, String prefix) {
+    public List<Text> evaluate(List<String> values, String prefix) {
         return evaluate(values, prefix, false);
     }
 
-    public List<String> evaluate(List<String> values, String prefix, boolean useIndexAsPrefix) {
+    public List<Text> evaluate(List<String> values, String prefix, boolean useIndexAsPrefix) {
         if(values == null) {
             return null;
         }
         if(prefix == null) {
             prefix = "";
         }
 
-        List<Integer> hashValues = ArrayHashValuesUDF.hashValues(values, null, MurmurHash3.DEFAULT_NUM_FEATURES, useIndexAsPrefix);
+        List<IntWritable> hashValues = ArrayHashValuesUDF.hashValues(values, null, MurmurHash3.DEFAULT_NUM_FEATURES, useIndexAsPrefix);
         final int len = hashValues.size();
-        final String[] stringValues = new String[len];
+        final Text[] stringValues = new Text[len];
         for(int i = 0; i < len; i++) {
-            Integer v = hashValues.get(i);
-            stringValues[i] = prefix + v.toString();
+            IntWritable v = hashValues.get(i);
+            stringValues[i] = val(prefix + v.toString());
         }
         return Arrays.asList(stringValues);
     }
18 changes: 10 additions & 8 deletions src/main/hivemall/ftvec/hashing/MurmurHash3UDF.java
@@ -20,52 +20,54 @@
  */
 package hivemall.ftvec.hashing;
 
+import static hivemall.utils.WritableUtils.val;
 import hivemall.utils.hashing.MurmurHash3;
 
 import java.util.List;
 
 import org.apache.hadoop.hive.ql.exec.UDF;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.io.IntWritable;
 
 public class MurmurHash3UDF extends UDF {
 
-    public int evaluate(String word) throws UDFArgumentException {
+    public IntWritable evaluate(String word) throws UDFArgumentException {
         return evaluate(word, MurmurHash3.DEFAULT_NUM_FEATURES);
     }
 
-    public int evaluate(String word, boolean rawValue) throws UDFArgumentException {
+    public IntWritable evaluate(String word, boolean rawValue) throws UDFArgumentException {
         if(rawValue) {
            if(word == null) {
                throw new UDFArgumentException("argument must not be null");
            }
-           return MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c);
+           return val(MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c));
        } else {
            return evaluate(word, MurmurHash3.DEFAULT_NUM_FEATURES);
        }
    }
 
-    public int evaluate(String word, int numFeatures) throws UDFArgumentException {
+    public IntWritable evaluate(String word, int numFeatures) throws UDFArgumentException {
        if(word == null) {
            throw new UDFArgumentException("argument must not be null");
        }
        int r = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c) % numFeatures;
        if(r < 0) {
            r += numFeatures;
        }
-        return r;
+        return val(r);
    }
 
-    public int evaluate(List<String> words) throws UDFArgumentException {
+    public IntWritable evaluate(List<String> words) throws UDFArgumentException {
        return evaluate(words, MurmurHash3.DEFAULT_NUM_FEATURES);
    }
 
-    public int evaluate(List<String> words, int numFeatures) throws UDFArgumentException {
+    public IntWritable evaluate(List<String> words, int numFeatures) throws UDFArgumentException {
        if(words == null) {
            throw new UDFArgumentException("argument must not be null");
        }
        final int size = words.size();
        if(size == 0) {
-            return 0;
+            return val(0);
        }
        final StringBuilder b = new StringBuilder();
        b.append(words.get(0));
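
A side note on evaluate(String, int) above: the r < 0 adjustment is required because Java's % operator keeps the sign of the dividend, so a negative hash would otherwise fall outside [0, numFeatures). A worked example (illustrative, not from the commit):

    // Java's % keeps the dividend's sign; shift negative remainders back
    // into the valid feature range [0, numFeatures).
    int hash = -7;
    int numFeatures = 5;
    int r = hash % numFeatures; // -2, not 3
    if(r < 0) {
        r += numFeatures;       // 3, a valid feature index
    }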