Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Fix option handle for aggressive parameter C. Add test methods for PA1

  • Loading branch information...
commit 095d9395f06ddf56a236192d1fdba07553cf527b 1 parent ccf8928
@smly smly authored
View
9 src/main/hivemall/classifier/PassiveAggressiveUDTF.java
@@ -25,6 +25,7 @@
import java.util.List;
+import org.apache.commons.cli.Options;
import org.apache.commons.cli.CommandLine;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -68,6 +69,14 @@ protected float eta(float loss, PredictionResult margin) {
protected float c;
@Override
+ protected Options getOptions() {
+ System.out.println("OverrideOptions");
+ Options opts = super.getOptions();
+ opts.addOption("c", "cparam", true, "Aggressiveness parameter C [default 1.0]");
+ return opts;
+ }
+
+ @Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
final CommandLine cl = super.processOptions(argOIs);
View
56 src/test/hivemall/classifier/PassiveAggressiveUDTFTest.java
@@ -28,22 +28,68 @@
import static org.junit.Assert.assertEquals;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.junit.Test;
public class PassiveAggressiveUDTFTest {
@Test
+ public void testPA1WithoutParameter() throws UDFArgumentException {
+ PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();
+ ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+ ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);
+
+ /* define aggressive parameter */
+ udtf.initialize(new ObjectInspector[]{intListOI, intOI});
+
+ /* train weights by List<Object> */
+ List<Integer> list = new ArrayList<Integer>();
+ list.add(1);
+ list.add(2);
+ list.add(3);
+ List<?> features1 = (List<?>) intListOI.getList(list);
+ udtf.train(features1, 1);
+
+ /* check weights */
+ assertEquals(0.3333333f, udtf.weights.get(1).get(), 1e-5f);
+ assertEquals(0.3333333f, udtf.weights.get(2).get(), 1e-5f);
+ assertEquals(0.3333333f, udtf.weights.get(3).get(), 1e-5f);
+ }
+
+ @Test
+ public void testPA1WithParameter() throws UDFArgumentException {
+ PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();
+ ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+ ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);
+
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ new String("-c 0.1")
+ );
+ /* define aggressive parameter */
+ udtf.initialize(new ObjectInspector[]{intListOI, intOI, param});
+
+ /* train weights by List<Object> */
+ List<Integer> list = new ArrayList<Integer>();
+ list.add(1);
+ list.add(2);
+ list.add(3);
+ List<?> features1 = (List<?>) intListOI.getList(list);
+ udtf.train(features1, 1);
+
+ /* check weights */
+ assertEquals(0.1000000f, udtf.weights.get(1).get(), 1e-5f);
+ assertEquals(0.1000000f, udtf.weights.get(2).get(), 1e-5f);
+ assertEquals(0.1000000f, udtf.weights.get(3).get(), 1e-5f);
+ }
+
+ @Test
public void testInitialize() throws UDFArgumentException {
PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF();
ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
Please sign in to comment.
Something went wrong with that request. Please try again.