From c78b75b0683a40c6fb55f2ae8cb23ce636b5b517 Mon Sep 17 00:00:00 2001
From: myui
Date: Fri, 8 Apr 2016 14:43:54 +0900
Subject: [PATCH] Added bitset_collect UDAF

---
 .../tools/bits/BitsetCollectUDAF.java      | 145 ++++++++++++++++++
 resources/ddl/define-all-as-permanent.hive |   7 +
 resources/ddl/define-all.hive              |   7 +
 3 files changed, 159 insertions(+)
 create mode 100644 core/src/main/java/hivemall/tools/bits/BitsetCollectUDAF.java

diff --git a/core/src/main/java/hivemall/tools/bits/BitsetCollectUDAF.java b/core/src/main/java/hivemall/tools/bits/BitsetCollectUDAF.java
new file mode 100644
index 00000000..7b60b2d7
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/bits/BitsetCollectUDAF.java
@@ -0,0 +1,145 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.tools.bits;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.BitSet;
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.LongWritable;
+
+@Description(name = "bitset_collect",
+        value = "_FUNC_(int|long x) - Returns a bitset in array")
+public final class BitsetCollectUDAF extends AbstractGenericUDAFResolver {
+
+    @Override
+    public GenericUDAFEvaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
+        if (typeInfo.length != 1) {
+            throw new UDFArgumentTypeException(typeInfo.length - 1,
+                "Exactly one argument is expected");
+        }
+        if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) {
+            throw new UDFArgumentTypeException(0, "_FUNC_(int|long x) is expected: " + typeInfo[0]);
+        }
+        return new Evaluator();
+    }
+
+    public static class Evaluator extends GenericUDAFEvaluator {
+        private PrimitiveObjectInspector inputOI;
+        private StandardListObjectInspector internalMergeOI;
+
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] argOIs) throws HiveException {
+            assert (argOIs.length == 1);
+            super.init(mode, argOIs);
+
+            // initialize input
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+                this.inputOI = HiveUtils.asLongCompatibleOI(argOIs[0]);
+            } else {// from partial aggregation
+                this.internalMergeOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            }
+
+            // initialize output
+            final ObjectInspector outputOI;
+            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+                outputOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            } else {// terminate
+                outputOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            }
+            return outputOI;
+        }
+
+        static class ArrayAggregationBuffer implements AggregationBuffer {
+            BitSet bitset;
+        }
+
+        @Override
+        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+            ArrayAggregationBuffer ret = new ArrayAggregationBuffer();
+            reset(ret);
+            return ret;
+        }
+
+        @Override
+        public void reset(AggregationBuffer aggr) throws HiveException {
+            ((ArrayAggregationBuffer) aggr).bitset = new BitSet();
+        }
+
+        @Override
+        public void iterate(AggregationBuffer aggr, Object[] parameters) throws HiveException {
+            assert (parameters.length == 1);
+            Object arg = parameters[0];
+            if (arg != null) {
+                int index = PrimitiveObjectInspectorUtils.getInt(arg, inputOI);
+                if (index < 0) {
+                    throw new UDFArgumentException("Specified index SHOULD NOT be negative: "
+                            + index);
+                }
+                ArrayAggregationBuffer agg = (ArrayAggregationBuffer) aggr;
+                agg.bitset.set(index);
+            }
+        }
+
+        @Override
+        public List terminatePartial(AggregationBuffer aggr) throws HiveException {
+            ArrayAggregationBuffer agg = (ArrayAggregationBuffer) aggr;
+            if (agg.bitset == null || agg.bitset.isEmpty()) {
+                return null;
+            }
+            long[] array = agg.bitset.toLongArray();
+            return WritableUtils.toWritableList(array);
+        }
+
+        @Override
+        public void merge(AggregationBuffer aggr, Object other) throws HiveException {
+            if (other == null) {
+                return;
+            }
+            ArrayAggregationBuffer agg = (ArrayAggregationBuffer) aggr;
+            long[] longs = HiveUtils.asLongArray(other, internalMergeOI,
+                PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            BitSet otherBitset = BitSet.valueOf(longs);
+            agg.bitset.or(otherBitset);
+        }
+
+        @Override
+        public List terminate(AggregationBuffer aggr) throws HiveException {
+            ArrayAggregationBuffer agg = (ArrayAggregationBuffer) aggr;
+            long[] longs = agg.bitset.toLongArray();
+            return WritableUtils.toWritableList(longs);
+        }
+    }
+}
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index f9792ae4..0156e570 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -371,6 +371,13 @@ CREATE FUNCTION array_sum as 'hivemall.tools.array.ArraySumUDAF' USING JAR '${hi
 DROP FUNCTION IF EXISTS to_string_array;
 CREATE FUNCTION to_string_array as 'hivemall.tools.array.ToStringArrayUDF' USING JAR '${hivemall_jar}';
 
+-----------------------------
+-- bit operation functions --
+-----------------------------
+
+DROP FUNCTION IF EXISTS bitset_collect;
+CREATE FUNCTION bitset_collect as 'hivemall.tools.bits.BitsetCollectUDAF' USING JAR '${hivemall_jar}';
+
 ---------------------------
 -- compression functions --
 ---------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 7328714e..a72f7d20 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -367,6 +367,13 @@ create temporary function array_sum as 'hivemall.tools.array.ArraySumUDAF';
 drop temporary function to_string_array;
 create temporary function to_string_array as 'hivemall.tools.array.ToStringArrayUDF';
 
+-----------------------------
+-- bit operation functions --
+-----------------------------
+
+drop temporary function bitset_collect;
+create temporary function bitset_collect as 'hivemall.tools.bits.BitsetCollectUDAF';
+
 ---------------------------
 -- compression functions --
 ---------------------------