Skip to content
This repository
Browse code

Added rand_amplify()

  • Loading branch information...
commit 538d86688b90ec13fdbf565e767537041364f1e6 1 parent 01f9b40
Makoto YUI authored November 16, 2013
3  scripts/ddl/define-all.hive
@@ -128,6 +128,9 @@ create temporary function zscore as 'hivemall.ftvec.scaling.ZScoreUDF';
128 128
 drop temporary function amplify;
129 129
 create temporary function amplify as 'hivemall.ftvec.AmplifierUDTF';
130 130
 
  131
+drop temporary function rand_amplify;
  132
+create temporary function rand_amplify as 'hivemall.ftvec.RandAmplifierUDTF';
  133
+
131 134
 drop temporary function conv2dense;
132 135
 create temporary function conv2dense as 'hivemall.ftvec.ConvertToDenseModelUDAF';
133 136
 
3  scripts/ddl/define-ftvec-udf.hive
@@ -42,6 +42,9 @@ create temporary function zscore as 'hivemall.ftvec.scaling.ZScoreUDF';
42 42
 drop temporary function amplify;
43 43
 create temporary function amplify as 'hivemall.ftvec.AmplifierUDTF';
44 44
 
  45
+drop temporary function rand_amplify;
  46
+create temporary function rand_amplify as 'hivemall.ftvec.RandAmplifierUDTF';
  47
+
45 48
 drop temporary function conv2dense;
46 49
 create temporary function conv2dense as 'hivemall.ftvec.ConvertToDenseModelUDAF';
47 50
 
122  src/main/hivemall/ftvec/RandAmplifierUDTF.java
... ...
@@ -0,0 +1,122 @@
  1
+/*
  2
+ * Hivemall: Hive scalable Machine Learning Library
  3
+ *
  4
+ * Copyright (C) 2013
  5
+ *   National Institute of Advanced Industrial Science and Technology (AIST)
  6
+ *   Registration Number: H25PRO-1520
  7
+ *
  8
+ * This library is free software; you can redistribute it and/or
  9
+ * modify it under the terms of the GNU Lesser General Public
  10
+ * License as published by the Free Software Foundation.
  11
+ *
  12
+ * This library is distributed in the hope that it will be useful,
  13
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
  14
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
  15
+ * Lesser General Public License for more details.
  16
+ *
  17
+ * You should have received a copy of the GNU Lesser General Public
  18
+ * License along with this library; if not, write to the Free Software
  19
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
  20
+ */
  21
+package hivemall.ftvec;
  22
+
  23
+import hivemall.common.HivemallConstants;
  24
+import hivemall.utils.ArrayUtils;
  25
+
  26
+import java.util.ArrayList;
  27
+
  28
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
  29
+import org.apache.hadoop.hive.ql.metadata.HiveException;
  30
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
  31
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
  32
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
  33
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
  34
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantIntObjectInspector;
  35
+
  36
+public class RandAmplifierUDTF extends GenericUDTF {
  37
+
  38
+    private int xtimes;
  39
+    private int numBuffers;
  40
+
  41
+    private Object[][] _forwardBuffers;
  42
+    private int _position;
  43
+
  44
+    @Override
  45
+    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
  46
+        if(!(argOIs.length >= 3)) {
  47
+            throw new UDFArgumentException("rand_amplify(int xtimes, int num_buffers, *) takes at least three arguments");
  48
+        }
  49
+        // xtimes
  50
+        if(argOIs[0].getTypeName() != HivemallConstants.INT_TYPE_NAME) {
  51
+            throw new UDFArgumentException("First argument must be int: " + argOIs[0].getTypeName());
  52
+        }
  53
+        if(!(argOIs[0] instanceof WritableConstantIntObjectInspector)) {
  54
+            throw new UDFArgumentException("WritableConstantIntObjectInspector is expected for the first argument: "
  55
+                    + argOIs[0].getClass().getSimpleName());
  56
+        }
  57
+        this.xtimes = ((WritableConstantIntObjectInspector) argOIs[0]).getWritableConstantValue().get();
  58
+        if(!(xtimes >= 1)) {
  59
+            throw new UDFArgumentException("Illegal xtimes value: " + xtimes);
  60
+        }
  61
+        // num_buffers
  62
+        if(argOIs[1].getTypeName() != HivemallConstants.INT_TYPE_NAME) {
  63
+            throw new UDFArgumentException("Second argument must be int: "
  64
+                    + argOIs[1].getTypeName());
  65
+        }
  66
+        if(!(argOIs[1] instanceof WritableConstantIntObjectInspector)) {
  67
+            throw new UDFArgumentException("WritableConstantIntObjectInspector is expected for the second argument: "
  68
+                    + argOIs[1].getClass().getSimpleName());
  69
+        }
  70
+        this.numBuffers = ((WritableConstantIntObjectInspector) argOIs[1]).getWritableConstantValue().get();
  71
+        if(numBuffers < 2) {
  72
+            throw new UDFArgumentException("num_buffers must be greater than 2: " + numBuffers);
  73
+        }
  74
+        this._forwardBuffers = new Object[numBuffers][argOIs.length - 2];
  75
+        this._position = 0;
  76
+
  77
+        ArrayList<String> fieldNames = new ArrayList<String>();
  78
+        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
  79
+
  80
+        for(int i = 1; i < argOIs.length; i++) {
  81
+            fieldNames.add("c" + i);
  82
+            fieldOIs.add(argOIs[i]);
  83
+        }
  84
+
  85
+        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
  86
+    }
  87
+
  88
+    @Override
  89
+    public void process(Object[] args) throws HiveException {
  90
+        final Object[][] forwardBuffers = _forwardBuffers;
  91
+        for(int x = 0; x < xtimes; x++) { // amplify x times              
  92
+            final Object[] forwardObjs = forwardBuffers[_position];
  93
+            for(int i = 2; i < args.length; i++) {// copy a row
  94
+                forwardObjs[i - 2] = args[i];
  95
+            }
  96
+            _position++;
  97
+            if(_position == numBuffers) {
  98
+                shuffleAndForward(forwardBuffers, _position);
  99
+                this._position = 0;
  100
+            }
  101
+        }
  102
+    }
  103
+
  104
+    @Override
  105
+    public void close() throws HiveException {
  106
+        if(_position > 0) {
  107
+            shuffleAndForward(_forwardBuffers, _position);
  108
+        }
  109
+        this._forwardBuffers = null;
  110
+        this._position = 0;
  111
+    }
  112
+
  113
+    private void shuffleAndForward(final Object[][] forwardBuffers, final int numForwards)
  114
+            throws HiveException {
  115
+        ArrayUtils.shuffle(forwardBuffers, numForwards);
  116
+        for(int i = 0; i < numForwards; i++) {
  117
+            Object[] forwardObj = forwardBuffers[i];
  118
+            forward(forwardObj);
  119
+        }
  120
+    }
  121
+
  122
+}
BIN  target/hivemall.jar
Binary file not shown

0 notes on commit 538d866

Please sign in to comment.
Something went wrong with that request. Please try again.