Skip to content

Commit

Permalink
Add a new method for standard DotProduct for users seeking non-normal…
Browse files Browse the repository at this point in the history
…ized score (#876)

Co-authored-by: Shashank Paliwal <spaliwal@spaliwal-mn1.linkedin.biz>
  • Loading branch information
shashankiiit and Shashank Paliwal committed Dec 5, 2022
1 parent 90a6b6d commit e9495d4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,31 @@ public static Float cosineSimilarity(Object obj1, Object obj2) {
}
}

/**
* Returns a standard dotProduct of two vector objects.
* Use {@link MvelContextUDFs#cosineSimilarity(Object, Object)} for normalized dot-product.
*/
@ExportToMvel
public static Double dotProduct(Object obj1, Object obj2) {
if (obj1 == null || obj2 == null) {
return null;
}
Map<String, Float> mapA = CoercionUtils.coerceToVector(obj1);
Map<String, Float> mapB = CoercionUtils.coerceToVector(obj2);
double dotProduct = 0;

for (Map.Entry<String, Float> entry : mapA.entrySet()) {
String k = entry.getKey();
float valA = entry.getValue();
Float valB = mapB.get(k);
if (valB != null) {
dotProduct += ((double) valA * valB);
}
}

return dotProduct;
}

/**
* convert input to lower case string
* @param input input string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,24 @@ public void testCosineSimilarity() {
categoricalOutput2.clear();
assertEquals(cosineSimilarity(categoricalOutput1, categoricalOutput2), 0.0F);
}

@Test
public void testDotProduct() {
// Test basic dot product calculation
Map<String, Float> categoricalOutput1 = new HashMap<>();
categoricalOutput1.put("A", 1F);
categoricalOutput1.put("B", 1F);

Map<String, Float> categoricalOutput2 = new HashMap<>();
categoricalOutput2.put("B", 1F);
categoricalOutput2.put("C", 1F);

assertEquals(dotProduct(categoricalOutput1, categoricalOutput2), 1.0D);

// Test dot product of zero vectors
categoricalOutput1.clear();
assertEquals(dotProduct(categoricalOutput1, categoricalOutput2), 0.0D);
categoricalOutput2.clear();
assertEquals(dotProduct(categoricalOutput1, categoricalOutput2), 0.0D);
}
}

0 comments on commit e9495d4

Please sign in to comment.