Skip to content
This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Commit

Permalink
Added bitset_collect UDAF
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Apr 8, 2016
1 parent 1613902 commit c78b75b
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 0 deletions.
145 changes: 145 additions & 0 deletions core/src/main/java/hivemall/tools/bits/BitsetCollectUDAF.java
@@ -0,0 +1,145 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2015 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hivemall.tools.bits;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;

import java.util.BitSet;
import java.util.List;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

@Description(name = "bitset_collect",
value = "_FUNC_(int|long x) - Retrurns a bitset in array<long>")
public final class BitsetCollectUDAF extends AbstractGenericUDAFResolver {

@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
if (typeInfo.length != 1) {
throw new UDFArgumentTypeException(typeInfo.length - 1,
"Exactly one argument is expected");
}
if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) {
throw new UDFArgumentTypeException(0, "_FUNC_(int|long x) is expected: " + typeInfo[0]);
}
return new Evaluator();
}

public static class Evaluator extends GenericUDAFEvaluator {
private PrimitiveObjectInspector inputOI;
private StandardListObjectInspector internalMergeOI;

@Override
public ObjectInspector init(Mode mode, ObjectInspector[] argOIs) throws HiveException {
assert (argOIs.length == 1);
super.init(mode, argOIs);

// initialize input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
this.inputOI = HiveUtils.asLongCompatibleOI(argOIs[0]);
} else {// from partial aggregation
this.internalMergeOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
}

// initialize output
final ObjectInspector outputOI;
if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
outputOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
} else {// terminate
outputOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
}
return outputOI;
}

static class ArrayAggregationBuffer implements AggregationBuffer {
BitSet bitset;
}

@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
ArrayAggregationBuffer ret = new ArrayAggregationBuffer();
reset(ret);
return ret;
}

@Override
public void reset(AggregationBuffer aggr) throws HiveException {
((ArrayAggregationBuffer) aggr).bitset = new BitSet();
}

@Override
public void iterate(AggregationBuffer aggr, Object[] parameters) throws HiveException {
assert (parameters.length == 1);
Object arg = parameters[0];
if (arg != null) {
int index = PrimitiveObjectInspectorUtils.getInt(arg, inputOI);
if (index < 0) {
throw new UDFArgumentException("Specified index SHOULD NOT be negative: "
+ index);
}
ArrayAggregationBuffer agg = (ArrayAggregationBuffer) aggr;
agg.bitset.set(index);
}
}

@Override
public List<LongWritable> terminatePartial(AggregationBuffer aggr) throws HiveException {
ArrayAggregationBuffer agg = (ArrayAggregationBuffer) aggr;
long[] array = agg.bitset.toLongArray();
if (agg.bitset == null || agg.bitset.isEmpty()) {
return null;
}
return WritableUtils.toWritableList(array);
}

@Override
public void merge(AggregationBuffer aggr, Object other) throws HiveException {
if (other == null) {
return;
}
ArrayAggregationBuffer agg = (ArrayAggregationBuffer) aggr;
long[] longs = HiveUtils.asLongArray(other, internalMergeOI,
PrimitiveObjectInspectorFactory.writableLongObjectInspector);
BitSet otherBitset = BitSet.valueOf(longs);
agg.bitset.or(otherBitset);
}

@Override
public List<LongWritable> terminate(AggregationBuffer aggr) throws HiveException {
ArrayAggregationBuffer agg = (ArrayAggregationBuffer) aggr;
long[] longs = agg.bitset.toLongArray();
return WritableUtils.toWritableList(longs);
}
}
}
7 changes: 7 additions & 0 deletions resources/ddl/define-all-as-permanent.hive
Expand Up @@ -371,6 +371,13 @@ CREATE FUNCTION array_sum as 'hivemall.tools.array.ArraySumUDAF' USING JAR '${hi
DROP FUNCTION IF EXISTS to_string_array;
CREATE FUNCTION to_string_array as 'hivemall.tools.array.ToStringArrayUDF' USING JAR '${hivemall_jar}';

-----------------------------
-- bit operation functions --
-----------------------------

DROP FUNCTION IF EXISTS bitset_collect;
CREATE FUNCTION bitset_collect as 'hivemall.tools.bits.BitsetCollectUDAF' USING JAR '${hivemall_jar}';

---------------------------
-- compression functions --
---------------------------
Expand Down
7 changes: 7 additions & 0 deletions resources/ddl/define-all.hive
Expand Up @@ -367,6 +367,13 @@ create temporary function array_sum as 'hivemall.tools.array.ArraySumUDAF';
drop temporary function to_string_array;
create temporary function to_string_array as 'hivemall.tools.array.ToStringArrayUDF';

-----------------------------
-- bit operation functions --
-----------------------------

drop temporary function bitset_collect;
create temporary function bitset_collect as 'hivemall.tools.bits.BitsetCollectUDAF';

---------------------------
-- compression functions --
---------------------------
Expand Down

0 comments on commit c78b75b

Please sign in to comment.