Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Merge pull request #47 from myui/feature/amplifier

Feature/amplifier
  • Loading branch information...
commit 23b4f3b36efed4674b734b8657b089c30bb0b1a6 2 parents 17b0730 + 6098912
@myui authored
View
2  scripts/ddl/define-all.hive
@@ -129,7 +129,7 @@ drop temporary function amplify;
create temporary function amplify as 'hivemall.ftvec.AmplifierUDTF';
drop temporary function rand_amplify;
-create temporary function rand_amplify as 'hivemall.ftvec.RandAmplifierUDTF';
+create temporary function rand_amplify as 'hivemall.ftvec.RandomAmplifierUDTF';
drop temporary function conv2dense;
create temporary function conv2dense as 'hivemall.ftvec.ConvertToDenseModelUDAF';
View
2  scripts/ddl/define-ftvec-udf.hive
@@ -43,7 +43,7 @@ drop temporary function amplify;
create temporary function amplify as 'hivemall.ftvec.AmplifierUDTF';
drop temporary function rand_amplify;
-create temporary function rand_amplify as 'hivemall.ftvec.RandAmplifierUDTF';
+create temporary function rand_amplify as 'hivemall.ftvec.RandomAmplifierUDTF';
drop temporary function conv2dense;
create temporary function conv2dense as 'hivemall.ftvec.ConvertToDenseModelUDAF';
View
134 src/main/hivemall/common/RandomDropoutAmplifier.java
@@ -0,0 +1,134 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.common;
+
+import hivemall.utils.ArrayUtils;
+
+import java.util.Random;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+
+public class RandomDropoutAmplifier<T> {
+
+ private final int numBuffers;
+ private final int lastPos;
+ private final int xtimes;
+
+ private final AgedObject<T>[][] slots;
+ private int position;
+
+ private final Random[] randoms;
+
+ private DropoutListener<T> listener = null;
+
+ @SuppressWarnings("unchecked")
+ public RandomDropoutAmplifier(int numBuffers, int xtimes) {
+ if(numBuffers < 1) {
+ throw new IllegalArgumentException("numBuffers must be greater than 0: " + numBuffers);
+ }
+ if(xtimes < 1) {
+ throw new IllegalArgumentException("xtime must be greater than 0: " + xtimes);
+ }
+ this.numBuffers = numBuffers;
+ this.lastPos = numBuffers - 1;
+ this.xtimes = xtimes;
+ this.slots = new AgedObject[xtimes][numBuffers];
+ this.position = 0;
+ this.randoms = new Random[xtimes];
+ for(int i = 0; i < xtimes; i++) {
+ randoms[i] = new Random();
+ }
+ }
+
+ public void setDropoutListener(DropoutListener<T> listener) {
+ this.listener = listener;
+ }
+
+ public void add(T storedObj) throws HiveException {
+ if(position < numBuffers) {
+ for(int x = 0; x < xtimes; x++) {
+ slots[x][position] = new AgedObject<T>(storedObj);
+ }
+ position++;
+ if(position == numBuffers) {
+ for(int x = 0; x < xtimes; x++) {
+ ArrayUtils.shuffle(slots[x], randoms[x]);
+ }
+ }
+ } else {
+ for(int x = 0; x < xtimes; x++) {
+ AgedObject<T>[] slot = slots[x];
+ Random rnd = randoms[x];
+ int rindex1 = rnd.nextInt(lastPos);
+ int rindex2 = rnd.nextInt(lastPos);
+ AgedObject<T> replaced1 = slot[rindex1];
+ AgedObject<T> replaced2 = slot[rindex2];
+ assert (replaced1 != null);
+ assert (replaced2 != null);
+ if(replaced1.timestamp >= replaced2.timestamp) {// bias to hold old entry
+ dropout(replaced1.object);
+ slot[rindex1] = new AgedObject<T>(storedObj);
+ } else {
+ dropout(replaced2.object);
+ slot[rindex2] = new AgedObject<T>(storedObj);
+ }
+ }
+ }
+ }
+
+ public void sweepAll() throws HiveException {
+ for(int i = 0; i < numBuffers; i++) {
+ for(int x = 0; x < xtimes; x++) {
+ AgedObject<T>[] slot = slots[x];
+ AgedObject<T> sweepedObj = slot[i];
+ if(sweepedObj != null) {
+ dropout(sweepedObj.object);
+ slot[i] = null;
+ }
+ }
+ }
+ }
+
+ protected void dropout(T droppped) throws HiveException {
+ if(droppped == null) {
+ throw new IllegalStateException("Illegal condition that dropped object is null");
+ }
+ if(listener != null) {
+ listener.onDrop(droppped);
+ }
+ }
+
+ private static final class AgedObject<T> {
+
+ private final T object;
+ private final long timestamp;
+
+ AgedObject(T obj) {
+ this.object = obj;
+ this.timestamp = System.nanoTime();
+ }
+ }
+
+ public interface DropoutListener<T> {
+ void onDrop(T droppped) throws HiveException;
+ }
+
+}
View
58 src/main/hivemall/ftvec/RandAmplifierUDTF.java → src/main/hivemall/ftvec/RandomAmplifierUDTF.java
@@ -21,7 +21,8 @@
package hivemall.ftvec;
import hivemall.common.HivemallConstants;
-import hivemall.utils.ArrayUtils;
+import hivemall.common.RandomDropoutAmplifier;
+import hivemall.common.RandomDropoutAmplifier.DropoutListener;
import java.util.ArrayList;
@@ -35,16 +36,12 @@
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantIntObjectInspector;
-public class RandAmplifierUDTF extends GenericUDTF {
-
- private int xtimes;
- private int numBuffers;
-
- private Object[][] _forwardBuffers;
- private int _position;
+public class RandomAmplifierUDTF extends GenericUDTF implements DropoutListener<Object[]> {
private transient ObjectInspector[] retrunOIs;
+ private transient RandomDropoutAmplifier<Object[]> amplifier;
+
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
final int numArgs = argOIs.length;
@@ -59,7 +56,7 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
throw new UDFArgumentException("WritableConstantIntObjectInspector is expected for the first argument: "
+ argOIs[0].getClass().getSimpleName());
}
- this.xtimes = ((WritableConstantIntObjectInspector) argOIs[0]).getWritableConstantValue().get();
+ int xtimes = ((WritableConstantIntObjectInspector) argOIs[0]).getWritableConstantValue().get();
if(!(xtimes >= 1)) {
throw new UDFArgumentException("Illegal xtimes value: " + xtimes);
}
@@ -72,14 +69,13 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
throw new UDFArgumentException("WritableConstantIntObjectInspector is expected for the second argument: "
+ argOIs[1].getClass().getSimpleName());
}
- this.numBuffers = ((WritableConstantIntObjectInspector) argOIs[1]).getWritableConstantValue().get();
+ int numBuffers = ((WritableConstantIntObjectInspector) argOIs[1]).getWritableConstantValue().get();
if(numBuffers < 2) {
throw new UDFArgumentException("num_buffers must be greater than 2: " + numBuffers);
}
- int numForwardObjs = numArgs - 2;
- this._forwardBuffers = new Object[numBuffers][numForwardObjs];
- this._position = 0;
+ this.amplifier = new RandomDropoutAmplifier<Object[]>(numBuffers, xtimes);
+ amplifier.setDropoutListener(this);
final ArrayList<String> fieldNames = new ArrayList<String>();
final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
@@ -96,38 +92,24 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
@Override
public void process(Object[] args) throws HiveException {
- final Object[][] forwardBuffers = _forwardBuffers;
- for(int x = 0; x < xtimes; x++) { // amplify x times
- final Object[] forwardObjs = forwardBuffers[_position];
- for(int i = 2; i < args.length; i++) {// copy a row
- Object arg = args[i];
- ObjectInspector returnOI = retrunOIs[i];
- forwardObjs[i - 2] = ObjectInspectorUtils.copyToStandardObject(arg, returnOI);
- }
- _position++;
- if(_position == numBuffers) {
- shuffleAndForward(forwardBuffers, _position);
- this._position = 0;
- }
+ final Object[] row = new Object[args.length - 2];
+ for(int i = 2; i < args.length; i++) {
+ Object arg = args[i];
+ ObjectInspector returnOI = retrunOIs[i];
+ row[i - 2] = ObjectInspectorUtils.copyToStandardObject(arg, returnOI);
}
+ amplifier.add(row);
}
@Override
public void close() throws HiveException {
- if(_position > 0) {
- shuffleAndForward(_forwardBuffers, _position);
- }
- this._forwardBuffers = null;
- this._position = 0;
+ amplifier.sweepAll();
+ this.amplifier = null;
}
- private void shuffleAndForward(final Object[][] forwardBuffers, final int numForwards)
- throws HiveException {
- ArrayUtils.shuffle(forwardBuffers, numForwards);
- for(int i = 0; i < numForwards; i++) {
- Object[] forwardObj = forwardBuffers[i];
- forward(forwardObj);
- }
+ @Override
+ public void onDrop(Object[] row) throws HiveException {
+ forward(row);
}
}
View
74 src/test/hivemall/common/RandomDropoutAmplifierTest.java
@@ -0,0 +1,74 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.common;
+
+import hivemall.common.RandomDropoutAmplifier.DropoutListener;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+
+import junit.framework.Assert;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.junit.Test;
+
+public class RandomDropoutAmplifierTest {
+
+ @Test
+ public void test() throws HiveException {
+ int size = 10000;
+ Integer[] numlist = new Integer[size];
+ for(int i = 0; i < size; i++) {
+ numlist[i] = i;
+ }
+
+ int xtimes = 3;
+ RandomDropoutAmplifier<Integer> amplifier = new RandomDropoutAmplifier<Integer>(1000, xtimes);
+ DropoutCollector collector = new DropoutCollector();
+ amplifier.setDropoutListener(collector);
+ for(Integer obj : numlist) {
+ amplifier.add(obj);
+ }
+ amplifier.sweepAll();
+
+ Assert.assertEquals(size * xtimes, collector.count);
+ Assert.assertEquals(size, collector.numset.size());
+
+ Set<Integer> expectedSet = new HashSet<Integer>(Arrays.asList(numlist));
+ Assert.assertEquals(expectedSet, collector.numset);
+ }
+
+ private class DropoutCollector implements DropoutListener<Integer> {
+
+ private int count = 0;
+ private final Set<Integer> numset = new HashSet<Integer>(10000);
+
+ @Override
+ public void onDrop(Integer droppped) {
+ //System.out.println(droppped);
+ numset.add(droppped);
+ count++;
+ }
+
+ }
+
+}
View
BIN  target/hivemall.jar
Binary file not shown
Please sign in to comment.
Something went wrong with that request. Please try again.