Skip to content
This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Commit

Permalink
Fixed a bug in sigmoid(null). PLT-4718
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Oct 26, 2015
1 parent cd12905 commit ffe213a
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 3 deletions.
3 changes: 2 additions & 1 deletion scripts/ddl/define-all-as-permanent.hive
Expand Up @@ -454,7 +454,8 @@ CREATE FUNCTION to_ordered_map as 'hivemall.tools.map.UDAFToOrderedMap' USING JA
---------------------

DROP FUNCTION IF EXISTS sigmoid;
CREATE FUNCTION sigmoid as 'hivemall.tools.math.SigmodUDF' USING JAR '${hivemall_jar}';
CREATE FUNCTION sigmoid as 'hivemall.tools.math.SigmoidGenericUDF' USING JAR '${hivemall_jar}';
-- CREATE FUNCTION sigmoid as 'hivemall.tools.math.SigmoidUDF' USING JAR '${hivemall_jar}';

----------------------
-- mapred functions --
Expand Down
3 changes: 2 additions & 1 deletion scripts/ddl/define-all.hive
Expand Up @@ -450,7 +450,8 @@ create temporary function to_ordered_map as 'hivemall.tools.map.UDAFToOrderedMap
---------------------

drop temporary function sigmoid;
create temporary function sigmoid as 'hivemall.tools.math.SigmodUDF';
create temporary function sigmoid as 'hivemall.tools.math.SigmoidGenericUDF';
-- create temporary function sigmoid as 'hivemall.tools.math.SigmoidUDF';

----------------------
-- mapred functions --
Expand Down
76 changes: 76 additions & 0 deletions src/main/java/hivemall/tools/math/SigmoidGenericUDF.java
@@ -0,0 +1,76 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2015 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hivemall.tools.math;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.math.MathUtils;

import java.util.Arrays;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;

@Description(name = "sigmoid", value = "_FUNC_(x) - Returns 1.0 / (1.0 + exp(-x))")
@UDFType(deterministic = true, stateful = false)
public final class SigmoidGenericUDF extends GenericUDF {

private PrimitiveObjectInspector argOI;

@Override
public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
throws UDFArgumentException {
if(argOIs.length != 1) {
throw new UDFArgumentException("_FUNC_ takes 1 argument");
}
this.argOI = HiveUtils.asDoubleCompatibleOI(argOIs[0]);
return PrimitiveObjectInspectorFactory.writableFloatObjectInspector;
}

@Nullable
@Override
public FloatWritable evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
assert (arguments.length == 1) : "sigmoid takes 1 argument: " + arguments.length;
DeferredObject arg0 = arguments[0];
assert (arg0 != null);
Object obj0 = arg0.get();
if(obj0 == null) {
return null;
}
double x = PrimitiveObjectInspectorUtils.getDouble(obj0, argOI);
float v = (float) MathUtils.sigmoid(x);
return new FloatWritable(v);
}

@Override
public String getDisplayString(String[] children) {
return "sigmoid(" + Arrays.toString(children) + ')';
}

}
Expand Up @@ -25,9 +25,10 @@
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.io.FloatWritable;

@Deprecated
@Description(name = "sigmoid", value = "_FUNC_(x) - Returns 1.0 / (1.0 + exp(-x))")
@UDFType(deterministic = true, stateful = false)
public final class SigmodUDF extends UDF {
public final class SigmoidUDF extends UDF {

public FloatWritable evaluate(float x) {
return val(1.0f / (1.0f + (float) Math.exp(-x)));
Expand Down
Binary file modified target/hivemall-fat.jar
Binary file not shown.
Binary file modified target/hivemall-with-dependencies.jar
Binary file not shown.
Binary file modified target/hivemall.jar
Binary file not shown.

0 comments on commit ffe213a

Please sign in to comment.