Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#6058 TSNE: throw useful exception for no points (distance may be undefined) #6094

Merged
merged 1 commit into from Aug 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -242,6 +242,12 @@ public INDArray computeGaussianPerplexity(final INDArray d, double u) {
tree.search(d.slice(i), k + 1, results, new ArrayList<Double>());
double betas = beta.getDouble(i);

if(results.size() == 0){
throw new IllegalStateException("Search returned no values for vector " + i +
" - similarity \"" + simiarlityFunction + "\" may not be defined (for example, vector is" +
" all zeros with cosine similarity)");
}

INDArray cArr = VPTree.buildFromData(results);
Pair<INDArray, Double> pair = computeGaussianKernel(cArr, beta.getDouble(i), k);
INDArray currP = pair.getFirst();
Expand Down
@@ -0,0 +1,47 @@
package org.deeplearning4j.plot;

import lombok.val;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;

import static org.junit.Assert.assertTrue;

public class Test6058 {

@Test
public void test() throws Exception {
//All zero input -> cosine similarity isn't defined
//https://github.com/deeplearning4j/deeplearning4j/issues/6058
val iterations = 10;
val cacheList = new ArrayList<String>();

int nWords = 100;
for(int i=0; i<nWords; i++ ) {
cacheList.add("word_" + i);
}

//STEP 3: build a dual-tree tsne to use later
System.out.println("Build model....");
val tsne = new BarnesHutTsne.Builder()
.setMaxIter(iterations)
.theta(0.5)
.normalize(false)
.learningRate(1000)
.useAdaGrad(false)
//.usePca(false)
.build();

System.out.println("fit");
INDArray weights = Nd4j.rand(new int[]{nWords, 100});
weights.getRow(1).assign(0);
try {
tsne.fit(weights);
} catch (IllegalStateException e){
assertTrue(e.getMessage().contains("may not be defined"));
}
}

}