Skip to content

Commit

Permalink
KGvec2go similarity functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
janothan committed Aug 6, 2020
1 parent 91bb427 commit 70d9cf7
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,18 @@ private void shutDown(){
}


/**
* Retrieve the vector of a word in the form of a Double array.
* @param word Word for lookup.
* @param dataset Dataset for lookup.
* @return Null in case of failure, else vector.
*/
public Double[] getVector(String word, KGvec2goDatasets dataset){
// sanity check
if(word == null || dataset == null){
return null;
}

// check buffer
if(isInBuffer(word, dataset)){
return getFromBuffer(word, dataset);
Expand Down Expand Up @@ -144,6 +155,18 @@ public Double[] getVector(String word, KGvec2goDatasets dataset){
}


/**
 * Computes the cosine similarity between the vectors of two words.
 * Both vectors are obtained through {@code getVector} using the same dataset
 * and are then handed to {@code cosineSimilarity}.
 *
 * @param word1 Lookup word 1.
 * @param word2 Lookup word 2.
 * @param datasets Dataset to be used for the lookup of both words.
 * @return Similarity of the two word vectors, or null if at least one of the vectors is null.
 * @throws ArithmeticException If the two vectors do not have the same dimension.
 */
public Double getSimilarity(String word1, String word2, KGvec2goDatasets datasets){
    // Look up the first word before the second one (same order as a nested call would use).
    final Double[] firstVector = getVector(word1, datasets);
    final Double[] secondVector = getVector(word2, datasets);
    return cosineSimilarity(firstVector, secondVector);
}


/**
* Look in the cache whether a vector already exists for the given word in the given dataset.
* @param word for which the vector shall be looked up.
Expand Down Expand Up @@ -184,4 +207,29 @@ private static void writeToBuffer(String word, KGvec2goDatasets dataset, Double[
vectorCache.put(key, vector);
}

/**
 * Calculates the cosine similarity between two vectors, i.e. their dot product
 * divided by the product of their Euclidean norms.
 * Note: if one of the norms is zero (e.g. empty or all-zero vectors), the division
 * yields NaN; this case is not handled separately.
 *
 * @param vector1 First vector.
 * @param vector2 Second vector.
 * @return Cosine similarity, or null if at least one of the vectors is null.
 * @throws ArithmeticException If the vectors do not have the same dimension.
 */
public static Double cosineSimilarity(Double[] vector1, Double[] vector2){
    // Nothing to compare if a vector is missing.
    if(vector1 == null || vector2 == null){
        return null;
    }

    final int dimension = vector1.length;
    if(dimension != vector2.length){
        LOGGER.error("ERROR - the vectors must be of the same dimension.");
        throw new ArithmeticException("The vectors must be of the same dimension");
    }

    double dot = 0.0;
    double squaredSum1 = 0.0;
    double squaredSum2 = 0.0;
    int index = 0;
    while (index < dimension) {
        // Unbox each component once and reuse it for the dot product and the norm.
        final double component1 = vector1[index];
        final double component2 = vector2[index];
        dot += component1 * component2;
        squaredSum1 += Math.pow(component1, 2);
        squaredSum2 += Math.pow(component2, 2);
        index++;
    }

    final double magnitude1 = Math.sqrt(squaredSum1);
    final double magnitude2 = Math.sqrt(squaredSum2);
    return dot / (magnitude1 * magnitude2);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,59 @@ void getVector() {
result = KGvec2goClient.getInstance().getVector("AABBCCDDEEFF", KGvec2goDatasets.WIKTIONARY);
assertNull(result);
}


/**
 * Checks {@code getSimilarity} on every dataset: a plausibility comparison between two
 * similarity scores, and null results when a word is null or cannot be found.
 */
@Test
void getSimilarity(){
    KGvec2goClient kgvec2go = KGvec2goClient.getInstance();

    // ---------
    // ALOD
    // ---------

    assertTrue(kgvec2go.getSimilarity("germany", "europe", KGvec2goDatasets.ALOD) > kgvec2go.getSimilarity("germany", "japan", KGvec2goDatasets.ALOD));
    assertNull(kgvec2go.getSimilarity("usa", null, KGvec2goDatasets.ALOD));
    assertNull(kgvec2go.getSimilarity(null, "usa", KGvec2goDatasets.ALOD));
    assertNull(kgvec2go.getSimilarity("AAABBBCCC", "usa", KGvec2goDatasets.ALOD));

    // ----------
    // DBpedia
    // ----------

    assertTrue(kgvec2go.getSimilarity("Germany", "Europe", KGvec2goDatasets.DBPEDIA) > kgvec2go.getSimilarity("Europe", "Japan", KGvec2goDatasets.DBPEDIA));
    // The null checks must target DBPEDIA as well; they previously queried ALOD by mistake.
    assertNull(kgvec2go.getSimilarity("USA", null, KGvec2goDatasets.DBPEDIA));
    assertNull(kgvec2go.getSimilarity(null, "USA", KGvec2goDatasets.DBPEDIA));
    assertNull(kgvec2go.getSimilarity("AAABBBCCC", "USA", KGvec2goDatasets.DBPEDIA));

    // ----------
    // WordNet
    // ----------

    assertTrue(kgvec2go.getSimilarity("Germany", "Europe", KGvec2goDatasets.WORDNET) > kgvec2go.getSimilarity("Europe", "Japan", KGvec2goDatasets.WORDNET));
    assertNull(kgvec2go.getSimilarity("USA", null, KGvec2goDatasets.WORDNET));
    assertNull(kgvec2go.getSimilarity(null, "USA", KGvec2goDatasets.WORDNET));
    assertNull(kgvec2go.getSimilarity("AAABBBCCC", "USA", KGvec2goDatasets.WORDNET));


    // --------------
    // Wiktionary
    // --------------

    assertTrue(kgvec2go.getSimilarity("Germany", "Europe", KGvec2goDatasets.WIKTIONARY) > kgvec2go.getSimilarity("Europe", "war", KGvec2goDatasets.WIKTIONARY));
    assertNull(kgvec2go.getSimilarity("USA", null, KGvec2goDatasets.WIKTIONARY));
    assertNull(kgvec2go.getSimilarity(null, "USA", KGvec2goDatasets.WIKTIONARY));
    assertNull(kgvec2go.getSimilarity("AAABBBCCC", "USA", KGvec2goDatasets.WIKTIONARY));

}


/**
 * Checks {@code KGvec2goClient.cosineSimilarity} against a pre-computed value and
 * verifies that vectors of different dimensions are rejected.
 */
@Test
void cosineSimilarity(){
    // Two vectors of equal dimension and one shorter vector for the mismatch case.
    final Double[] first = {3d, 8d, 7d, 5d, 2d, 9d};
    final Double[] second = {10d, 8d, 6d, 6d, 4d, 5d};
    final Double[] tooShort = {10d, 8d};

    final double expectedSimilarity = 0.8639;
    final double tolerance = 0.00001;

    assertEquals(expectedSimilarity, KGvec2goClient.cosineSimilarity(first, second), tolerance);

    // A dimension mismatch must be reported as an ArithmeticException.
    assertThrows(ArithmeticException.class, () -> KGvec2goClient.cosineSimilarity(first, tooShort));
}
}

0 comments on commit 70d9cf7

Please sign in to comment.