Skip to content
This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Commit

Permalink
Added euclid_similarity UDF
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Jul 21, 2015
1 parent d860107 commit 10d6e23
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 10 deletions.
3 changes: 3 additions & 0 deletions scripts/ddl/define-all-as-permanent.hive
Expand Up @@ -99,6 +99,9 @@ CREATE FUNCTION jaccard_similarity as 'hivemall.knn.similarity.JaccardIndexUDF'
DROP FUNCTION IF EXISTS angular_similarity;
CREATE FUNCTION angular_similarity as 'hivemall.knn.similarity.AngularSimilarityUDF' USING JAR '${hivemall_jar}';

DROP FUNCTION IF EXISTS euclid_similarity;
CREATE FUNCTION euclid_similarity as 'hivemall.knn.similarity.EuclidSimilarity' USING JAR '${hivemall_jar}';

------------------------
-- distance functions --
------------------------
Expand Down
3 changes: 3 additions & 0 deletions scripts/ddl/define-all-excluding-macro.hive
Expand Up @@ -95,6 +95,9 @@ create temporary function jaccard_similarity as 'hivemall.knn.similarity.Jaccard
drop temporary function angular_similarity;
create temporary function angular_similarity as 'hivemall.knn.similarity.AngularSimilarityUDF';

drop temporary function euclid_similarity;
create temporary function euclid_similarity as 'hivemall.knn.similarity.EuclidSimilarity';

------------------------
-- distance functions --
------------------------
Expand Down
3 changes: 3 additions & 0 deletions scripts/ddl/define-all.hive
Expand Up @@ -95,6 +95,9 @@ create temporary function jaccard_similarity as 'hivemall.knn.similarity.Jaccard
drop temporary function angular_similarity;
create temporary function angular_similarity as 'hivemall.knn.similarity.AngularSimilarityUDF';

drop temporary function euclid_similarity;
create temporary function euclid_similarity as 'hivemall.knn.similarity.EuclidSimilarity';

------------------------
-- distance functions --
------------------------
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/hivemall/knn/distance/EuclidDistanceUDF.java
Expand Up @@ -57,10 +57,11 @@ public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentEx
public FloatWritable evaluate(DeferredObject[] arguments) throws HiveException {
List<String> ftvec1 = HiveUtils.asStringList(arguments[0], arg0ListOI);
List<String> ftvec2 = HiveUtils.asStringList(arguments[1], arg1ListOI);
return evaluate(ftvec1, ftvec2);
float d = (float) euclidDistance(ftvec1, ftvec2);
return new FloatWritable(d);
}

public FloatWritable evaluate(final List<String> ftvec1, final List<String> ftvec2) {
public static double euclidDistance(final List<String> ftvec1, final List<String> ftvec2) {
final FeatureValue probe = new FeatureValue();
final Map<String, Float> map = new HashMap<String, Float>(ftvec1.size() * 2 + 1);
for(String ft : ftvec1) {
Expand Down Expand Up @@ -93,7 +94,7 @@ public FloatWritable evaluate(final List<String> ftvec1, final List<String> ftve
float v1f = e.getValue();
d += (v1f * v1f);
}
return new FloatWritable((float) Math.sqrt(d));
return Math.sqrt(d);
}

@Override
Expand Down
69 changes: 69 additions & 0 deletions src/main/java/hivemall/knn/similarity/EuclidSimilarity.java
@@ -0,0 +1,69 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2015 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hivemall.knn.similarity;

import hivemall.knn.distance.EuclidDistanceUDF;
import hivemall.utils.hadoop.HiveUtils;

import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.FloatWritable;

@Description(name = "euclid_similarity", value = "_FUNC_(ftvec1, ftvec2) - Returns a euclid distance based similarity"
+ ", which is `1.0 / (1.0 + distance)`, of the given two vectors")
@UDFType(deterministic = true, stateful = false)
public final class EuclidSimilarity extends GenericUDF {

private ListObjectInspector arg0ListOI, arg1ListOI;

@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if(argOIs.length != 2) {
throw new UDFArgumentException("euclid_similarity takes 2 arguments");
}
this.arg0ListOI = HiveUtils.asListOI(argOIs[0]);
this.arg1ListOI = HiveUtils.asListOI(argOIs[1]);

return PrimitiveObjectInspectorFactory.writableFloatObjectInspector;
}

@Override
public FloatWritable evaluate(DeferredObject[] arguments) throws HiveException {
List<String> ftvec1 = HiveUtils.asStringList(arguments[0], arg0ListOI);
List<String> ftvec2 = HiveUtils.asStringList(arguments[1], arg1ListOI);
float d = (float) EuclidDistanceUDF.euclidDistance(ftvec1, ftvec2);
float sim = 1.0f / (1.0f + d);
return new FloatWritable(sim);
}

@Override
public String getDisplayString(String[] children) {
return "euclid_similarity(" + Arrays.toString(children) + ")";
}

}
29 changes: 22 additions & 7 deletions src/test/java/hivemall/knn/distance/EuclidDistanceUDFTest.java
@@ -1,30 +1,45 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2015 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hivemall.knn.distance;

import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.io.FloatWritable;
import org.junit.Assert;
import org.junit.Test;

public class EuclidDistanceUDFTest {

@Test
public void test1() {
EuclidDistanceUDF udf = new EuclidDistanceUDF();
List<String> ftvec1 = Arrays.asList("1:1.0", "2:2.0", "3:3.0");
List<String> ftvec2 = Arrays.asList("1:2.0", "2:4.0", "3:6.0");
FloatWritable d = udf.evaluate(ftvec1, ftvec2);
Assert.assertEquals((float) Math.sqrt(1.0 + 4.0 + 9.0), d.get(), 0.f);
double d = EuclidDistanceUDF.euclidDistance(ftvec1, ftvec2);
Assert.assertEquals(Math.sqrt(1.0 + 4.0 + 9.0), d, 0.f);
}

@Test
public void test2() {
EuclidDistanceUDF udf = new EuclidDistanceUDF();
List<String> ftvec1 = Arrays.asList("1:1.0", "2:3.0", "3:3.0");
List<String> ftvec2 = Arrays.asList("1:2.0", "3:6.0");
FloatWritable d = udf.evaluate(ftvec1, ftvec2);
Assert.assertEquals((float) Math.sqrt(1.0 + 9.0 + 9.0), d.get(), 0.f);
double d = EuclidDistanceUDF.euclidDistance(ftvec1, ftvec2);
Assert.assertEquals(Math.sqrt(1.0 + 9.0 + 9.0), d, 0.f);
}

}

0 comments on commit 10d6e23

Please sign in to comment.