diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java index 5e6d4714d9b1..6997e9e09d7f 100644 --- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java @@ -242,6 +242,12 @@ public INDArray computeGaussianPerplexity(final INDArray d, double u) { tree.search(d.slice(i), k + 1, results, new ArrayList()); double betas = beta.getDouble(i); + if(results.size() == 0){ + throw new IllegalStateException("Search returned no values for vector " + i + + " - similarity \"" + simiarlityFunction + "\" may not be defined (for example, vector is" + + " all zeros with cosine similarity)"); + } + INDArray cArr = VPTree.buildFromData(results); Pair pair = computeGaussianKernel(cArr, beta.getDouble(i), k); INDArray currP = pair.getFirst(); diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java new file mode 100644 index 000000000000..fabb912e0d9b --- /dev/null +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java @@ -0,0 +1,47 @@ +package org.deeplearning4j.plot; + +import lombok.val; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; + +import static org.junit.Assert.assertTrue; + +public class Test6058 { + + @Test + public void test() throws Exception { + //All zero input -> cosine similarity isn't defined + //https://github.com/deeplearning4j/deeplearning4j/issues/6058 + val iterations = 10; + val cacheList = new ArrayList(); + + int nWords = 100; + for(int i=0; i