From 88af7b4b0e8424d3f9b983724fc009fcfb5a64b8 Mon Sep 17 00:00:00 2001 From: Makoto YUI Date: Fri, 4 Jul 2014 19:48:22 +0900 Subject: [PATCH] Added extract_weight(string featureVectors)::weights UDF --- scripts/ddl/define-all.hive | 3 + scripts/ddl/define-ftvec-udf.hive | 2 + src/main/hivemall/ftvec/ExtractWeightUDF.java | 65 +++++++++++++++++++ 3 files changed, 70 insertions(+) create mode 100644 src/main/hivemall/ftvec/ExtractWeightUDF.java diff --git a/scripts/ddl/define-all.hive b/scripts/ddl/define-all.hive index 4d7e714d..faab93cf 100644 --- a/scripts/ddl/define-all.hive +++ b/scripts/ddl/define-all.hive @@ -176,6 +176,9 @@ create temporary function sortByFeature as 'hivemall.ftvec.SortByFeatureUDF'; drop temporary function extract_feature; create temporary function extract_feature as 'hivemall.ftvec.ExtractFeatureUDF'; +drop temporary function extract_weight; +create temporary function extract_weight as 'hivemall.ftvec.ExtractWeightUDF'; + -------------------------- -- Regression functions -- -------------------------- diff --git a/scripts/ddl/define-ftvec-udf.hive b/scripts/ddl/define-ftvec-udf.hive index f556f965..f6b33557 100644 --- a/scripts/ddl/define-ftvec-udf.hive +++ b/scripts/ddl/define-ftvec-udf.hive @@ -57,3 +57,5 @@ create temporary function sortByFeature as 'hivemall.ftvec.SortByFeatureUDF'; drop temporary function extract_feature; create temporary function extract_feature as 'hivemall.ftvec.ExtractFeatureUDF'; +drop temporary function extract_weight; +create temporary function extract_weight as 'hivemall.ftvec.ExtractWeightUDF'; diff --git a/src/main/hivemall/ftvec/ExtractWeightUDF.java b/src/main/hivemall/ftvec/ExtractWeightUDF.java new file mode 100644 index 00000000..97e807f4 --- /dev/null +++ b/src/main/hivemall/ftvec/ExtractWeightUDF.java @@ -0,0 +1,65 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.ftvec; + +import java.util.Arrays; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDF; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.io.FloatWritable; + +@Description(name = "extract_weight", value = "_FUNC_(feature_vector in array) - Returns the weights of features in array") +@UDFType(deterministic = true, stateful = false) +public class ExtractWeightUDF extends UDF { + + public FloatWritable evaluate(String featureVector) throws UDFArgumentException { + return extractWeights(featureVector); + } + + public List evaluate(List featureVectors) throws UDFArgumentException { + if(featureVectors == null) { + return null; + } + final int size = featureVectors.size(); + final FloatWritable[] output = new FloatWritable[size]; + for(int i = 0; i < size; i++) { + String ftvec = featureVectors.get(i); + output[i] = extractWeights(ftvec); + } + return Arrays.asList(output); + } + + private static FloatWritable extractWeights(String ftvec) throws UDFArgumentException { + if(ftvec == null) { + return null; + } + String[] splits = ftvec.split(":"); + if(splits.length != 2) { + throw new UDFArgumentException("Unexpected feature vector representation: " + ftvec); + } + float f = Float.valueOf(splits[1]); + return new FloatWritable(f); + } + +}